Shortcuts

unike.module.strategy.CompGCNSampling 源代码

# coding:utf-8
#
# unike/module/strategy/CompGCNSampling.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 20, 2023
#
# 该脚本定义了 CompGCN 模型的训练策略.

"""
CompGCNSampling - 训练策略类,包含损失函数。
"""

import dgl
import torch
import typing
from ..loss import Loss
from ..model import CompGCN
from .Strategy import Strategy

class CompGCNSampling(Strategy):
    """
    Bundle a model and its loss function together for convenient training;
    used with ``CompGCN`` :cite:`CompGCN`.

    Example::

        from unike.module.model import CompGCN
        from unike.module.loss import CompGCNLoss
        from unike.module.strategy import CompGCNSampling
        from unike.config import Trainer, GraphTester

        # define the model
        compgcn = CompGCN(
            ent_tol = dataloader.train_sampler.ent_tol,
            rel_tol = dataloader.train_sampler.rel_tol,
            dim = 100
        )

        # define the training strategy (model + loss function)
        model = CompGCNSampling(
            model = compgcn,
            loss = CompGCNLoss(model = compgcn),
            ent_tol = dataloader.train_sampler.ent_tol
        )

        # test the model
        tester = GraphTester(model = compgcn, data_loader = dataloader, use_gpu = True,
            device = 'cuda:0', prediction = "tail")

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 2000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 50, log_interval = 50,
            save_interval = 50, save_path = '../../checkpoint/compgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        model: typing.Optional[CompGCN] = None,
        loss: typing.Optional[Loss] = None,
        smoothing: float = 0.1,
        ent_tol: typing.Optional[int] = None):

        """Create a CompGCNSampling object.

        :param model: the CompGCN model
        :type model: :py:class:`unike.module.model.CompGCN`
        :param loss: the loss function
        :type loss: :py:class:`unike.module.loss.Loss`
        :param smoothing: label-smoothing coefficient applied to the targets in
            :py:meth:`forward`
        :type smoothing: float
        :param ent_tol: number of entities; must be a positive number by the
            time :py:meth:`forward` is called
        :type ent_tol: int
        """

        super().__init__()

        #: The CompGCN model, i.e. :py:class:`unike.module.model.CompGCN`
        self.model: CompGCN = model
        #: The loss function, i.e. :py:class:`unike.module.loss.Loss`
        self.loss: Loss = loss
        #: Label-smoothing coefficient
        self.smoothing: float = smoothing
        #: Number of entities
        self.ent_tol: int = ent_tol

    def forward(
        self,
        data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]) -> torch.Tensor:

        """Compute the final loss value. Defines the computation performed on every call.

        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        :param data: one batch; the keys ``"graph"``, ``"relation"``, ``"norm"``,
            ``"sample"`` and ``"label"`` are read
        :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        :returns: the loss value
        :rtype: torch.Tensor
        :raises ValueError: if :py:attr:`ent_tol` is ``None`` or not positive
        """

        # Fail fast with a clear message instead of an opaque TypeError /
        # ZeroDivisionError from ``1.0 / self.ent_tol`` after the (costly)
        # model forward pass has already run.
        if self.ent_tol is None or self.ent_tol <= 0:
            raise ValueError(
                "CompGCNSampling requires a positive 'ent_tol' (number of entities), "
                f"got {self.ent_tol!r}."
            )

        graph = data["graph"]
        relation = data["relation"]
        norm = data["norm"]
        sample = data["sample"]
        label = data["label"]

        score = self.model(graph, relation, norm, sample)
        # Label smoothing: shrink the targets by (1 - smoothing) and add a
        # uniform 1 / ent_tol to every entry. Note the uniform term is
        # 1 / ent_tol, independent of ``smoothing``.
        # NOTE(review): presumably ``label`` is a 0/1 target tensor over all
        # entities, aligned with ``score`` -- confirm against the data loader.
        label = (1.0 - self.smoothing) * label + (1.0 / self.ent_tol)
        return self.loss(score, label)
def get_compgcn_sampling_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyperparameter-optimization config for :py:class:`CompGCNSampling`.

    The default configuration is::

        parameters_dict = {
            'strategy': {
                'value': 'CompGCNSampling'
            },
            'smoothing': {
                'value': 0.1
            }
        }

    A new dictionary is built on every call, so callers may mutate the result freely.

    :returns: the default hyperparameter-optimization config for :py:class:`CompGCNSampling`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # (parameter name, fixed value) pairs, in the order they appear in the config.
    defaults = (
        ("strategy", "CompGCNSampling"),
        ("smoothing", 0.1),
    )
    return {name: {"value": value} for name, value in defaults}

Docs

Access comprehensive developer documentation for UniKE

View Docs