Shortcuts

unike.module.strategy.RGCNSampling 源代码

# coding:utf-8
#
# unike/module/strategy/RGCNSampling.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 18, 2023
#
# 该脚本定义了 R-GCN 模型的训练策略.

"""
RGCNSampling - training strategy class for R-GCN that bundles the model with its loss function.
"""

import dgl
import torch
import typing
from ..loss import Loss
from ..model import Model
from .Strategy import Strategy

class RGCNSampling(Strategy):

    """
    Wrap a model and a loss function together to make training convenient;
    used for ``R-GCN`` :cite:`R-GCN`.

    Example::

        from unike.data import GraphDataLoader
        from unike.module.model import RGCN
        from unike.module.loss import RGCNLoss
        from unike.module.strategy import RGCNSampling
        from unike.config import Trainer, GraphTester

        dataloader = GraphDataLoader(
            in_path = "../../benchmarks/FB15K237/",
            batch_size = 60000,
            neg_ent = 10,
            test = True,
            test_batch_size = 100,
            num_workers = 16
        )

        # define the model
        rgcn = RGCN(
            ent_tol = dataloader.train_sampler.ent_tol,
            rel_tol = dataloader.train_sampler.rel_tol,
            dim = 500,
            num_layers = 2
        )

        # define the loss function
        model = RGCNSampling(
            model = rgcn,
            loss = RGCNLoss(model = rgcn, regularization = 1e-5)
        )

        # test the model
        tester = GraphTester(model = rgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 10000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 500, log_interval = 500,
            save_interval = 500, save_path = '../../checkpoint/rgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        model: Model = None,
        loss: Loss = None):

        """Create an RGCNSampling object.

        :param model: The R-GCN model.
        :type model: :py:class:`unike.module.model.RGCN`
        :param loss: The loss function.
        :type loss: :py:class:`unike.module.loss.Loss`
        """

        super().__init__()

        #: The R-GCN model, i.e. :py:class:`unike.module.model.RGCN`
        self.model: Model = model
        #: The loss function, i.e. :py:class:`unike.module.loss.Loss`
        self.loss: Loss = loss

    def forward(
        self,
        data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]) -> torch.Tensor:

        """Compute the final loss value. Defines the computation performed
        at every call. Subclasses of :py:class:`torch.nn.Module` must
        override :py:meth:`torch.nn.Module.forward`.

        :param data: The batch data; the entries ``graph``, ``entity``,
            ``relation``, ``norm``, ``triples`` and ``label`` are read from it.
        :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        :returns: The loss value.
        :rtype: torch.Tensor
        """

        # Pull every required entry out of the batch up front, so a missing
        # key raises before the model is run.
        graph, entity, relation, norm, triples, label = [
            data[key]
            for key in ("graph", "entity", "relation", "norm", "triples", "label")
        ]

        # The model scores the triples; the loss compares scores with labels.
        score = self.model(graph, entity, relation, norm, triples)
        return self.loss(score, label)
def get_rgcn_sampling_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration
    of :py:class:`RGCNSampling`.

    The default configuration is::

        parameters_dict = {
            'strategy': {
                'value': 'RGCNSampling'
            }
        }

    :returns: The default hyper-parameter optimization configuration
        of :py:class:`RGCNSampling`.
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # A new dictionary is built on every call, so callers may mutate
    # the result freely.
    return {
        'strategy': {
            'value': 'RGCNSampling',
        },
    }

Docs

Access comprehensive developer documentation for UniKE

View Docs