GAT.py
Class GCN(num_features:int , num_classes:int , hidden_channels:list=[8], num_heads:list=[8], dropout:float=0.6, attn_dropout:float=0.6):
GAT模型
PARAMETERS
num_features(int) - 输入层节点数,一般等于特征数。
num_classes(int) - 输出层节点数,一般等于类别数。
hidden_channels(list) - 隐藏层列表
num_heads(list) - 注意头数量列表
dropout(float) - 模型dropout率
attn_dropout(float) - 注意头dropout率
forward(x:tg.data.Data.x, edge_index:tg.data.Data.edge_index) -> Tensor
PARAMETERS
x(Tensor) - 节点特征张量,一般为二维张量,第一维表示节点,第二维表示某节点特征。
edge_index(Tensor) - 边张量,一般为2行e列,e为边数,2行对应列元素表示一条有向边。
Last updated