备注
Go to the end to download the full example code
CompGCN-FB15K237-single-gpu || CompGCN-FB15K237-single-gpu-wandb || CompGCN-FB15K237-single-gpu-hpo
CompGCN-FB15K237-single-gpu-wandb¶
备注
created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
备注
updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 23, 2024
备注
last run by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 23, 2024
这一部分介绍如何用一个 GPU 在 FB15K237 知识图谱上训练 CompGCN [VSNT20],使用 wandb 记录实验结果。
导入数据¶
UniKE 有一个工具用于导入数据: unike.data.KGEDataLoader。
import os
from unike.utils import WandbLogger
from unike.data import KGEDataLoader, CompGCNSampler, CompGCNTestSampler
from unike.module.model import CompGCN
from unike.module.loss import CompGCNLoss
from unike.module.strategy import CompGCNSampling
from unike.config import Trainer, Tester
首先初始化 unike.utils.WandbLogger 日志记录器,它是对 wandb 初始化操作的一层简单封装。
wandb_logger = WandbLogger().set_config(
project="unike",
name="compgcn",
config=dict(
in_path = os.path.join(os.path.dirname(__file__), '../../benchmarks/FB15K237/'),
batch_size = 2048,
test = True,
test_batch_size = 256,
num_workers = 16,
dim = 100,
use_gpu = True,
device = 'cuda:0',
prediction = "tail",
epochs = 2000,
lr = 0.0001,
valid_interval = 100,
log_interval = 100,
save_interval = 100,
save_path = '../../checkpoint/compgcn.pth'
)
)
config = wandb_logger.config
UniKE 提供了很多数据集,它们很多都是 KGE 原论文发表时附带的数据集。
unike.data.KGEDataLoader 包含 in_path 用于传递数据集目录。
dataloader = KGEDataLoader(
in_path = config.in_path,
batch_size = config.batch_size,
test = config.test,
test_batch_size = config.test_batch_size,
num_workers = config.num_workers,
train_sampler = CompGCNSampler,
test_sampler = CompGCNTestSampler
)
导入模型¶
UniKE 提供了很多 KGE 模型,它们都是目前最常用的基线模型。我们下面将要导入
unike.module.model.CompGCN,它提出于 2017 年,是第一个图神经网络模型,
# define the model
compgcn = CompGCN(
ent_tol = dataloader.get_ent_tol(),
rel_tol = dataloader.get_rel_tol(),
dim = config.dim
)
损失函数¶
我们这里使用了 CompGCN [VSNT20] 原论文使用的损失函数:unike.module.loss.CompGCNLoss,
unike.module.strategy.CompGCNSampling 对
unike.module.loss.CompGCNLoss 进行了封装。
# define the loss function
model = CompGCNSampling(
model = compgcn,
loss = CompGCNLoss(model = compgcn),
ent_tol = dataloader.get_ent_tol()
)
训练模型¶
UniKE 将训练循环包装成了 unike.config.Trainer,
可以运行它的 unike.config.Trainer.run() 函数进行模型学习;
也可以通过传入 unike.config.Tester,
使得训练器能够在训练过程中评估模型。
# test the model
tester = Tester(
model = compgcn, data_loader = dataloader,
use_gpu = config.use_gpu, device = config.device,
prediction = config.prediction
)
# train the model
trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
epochs = config.epochs, lr = config.lr, use_gpu = config.use_gpu, device = config.device,
tester = tester, test = config.test, valid_interval = config.valid_interval,
log_interval = config.log_interval, save_interval = config.save_interval,
save_path = config.save_path, wandb_logger = wandb_logger
)
trainer.run()
备注
上述代码的运行日志可以从 此处 下载。
备注
上述代码的运行报告可以从 此处 下载。
Total running time of the script: ( 0 minutes 0.000 seconds)