TTV.py

class TTV(data:torch_geometric.data.Data, model:nn.Module, optimizer:torch.optim.Optimizer, criterion:nn.Module, epochs:int, wandbname: str=""):

TTV, train valid test类,用于简化训练代码。

PARAMETERS

  • data(Tensor) - 数据,一般为一张图,包含了x,edge_index,mask等等

  • model(nn.Module) - 模型

  • optimizer(torch.optim.Optimizer) - 优化器

  • criterion(nn.Module) - 损失函数

  • epochs(int) - 训练轮数

  • wandbname(str) - wandb项目名称,为空时不上传wandb。

train() -> None

训练,包含valid,无打印,会上传wandb,如果指定了wandbname。

test() -> None

测试。打印acc。

tocsv(name:str='GCN') -> None

把模型预测的节点信息保存到csv。名称为name+_nodes.csv,两列,第一列node_id,第二列label。

Last updated