unike.module.loss.CompGCNLoss 源代码
# coding:utf-8
#
# unike/module/loss/CompGCNLoss.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 23, 2024
#
# 该脚本定义了 CompGCNLoss 类.
"""
CompGCNLoss - 损失函数类,CompGCN 原论文中应用这种损失函数完成模型学习。
"""
import torch
from .Loss import Loss
from typing import Any
from ..model import CompGCN
[文档]class CompGCNLoss(Loss):
"""
``CompGCN`` :cite:`CompGCN` 原论文中应用这种损失函数完成模型训练。
.. Note:: :py:meth:`forward` 中的正样本评分函数的得分应大于负样本评分函数的得分。
例子::
from unike.module.loss import CompGCNLoss
from unike.module.strategy import CompGCNSampling
# define the loss function
model = CompGCNSampling(
model = compgcn,
loss = CompGCNLoss(model = compgcn),
ent_tol = dataloader.get_ent_tol()
)
"""
[文档] def __init__(
self,
model: CompGCN):
"""创建 CompGCNLoss 对象。
:param model: 模型
:type model: CompGCN
"""
super(CompGCNLoss, self).__init__()
#: 模型
self.model: CompGCN = model
#: 损失函数
self.loss: torch.nn.BCELoss = torch.nn.BCELoss()
[文档] def forward(
self,
pred: torch.Tensor,
label: torch.Tensor) -> torch.Tensor:
"""计算 CompGCNLoss 损失函数。定义每次调用时执行的计算。:py:class:`torch.nn.Module` 子类必须重写 :py:meth:`torch.nn.Module.forward`。
:param pred: 模型的得分。
:type pred: torch.Tensor
:param labels: 标签
:type labels: torch.Tensor
:returns: 损失值
:rtype: torch.Tensor
"""
loss = self.loss(pred, label)
return loss
[文档]def get_compgcn_loss_hpo_config() -> dict[str, dict[str, Any]]:
"""返回 :py:class:`CompGCNLoss` 的默认超参数优化配置。
默认配置为::
parameters_dict = {
'loss': {
'value': 'CompGCNLoss'
}
}
:returns: :py:class:`CompGCNLoss` 的默认超参数优化配置
:rtype: dict[str, dict[str, typing.Any]]
"""
parameters_dict = {
'loss': {
'value': 'CompGCNLoss'
}
}
return parameters_dict