# Source code for unike.module.strategy.NegativeSampling
# coding:utf-8
#
# unike/module/strategy/NegativeSampling.py
#
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
#
# 该脚本定义了平移模型和语义匹配模型的训练策略.
"""
NegativeSampling - 训练策略类,包含损失函数。
"""
import torch
import typing
from ..loss import Loss
from ..model import Model
from .Strategy import Strategy
class NegativeSampling(Strategy):

    """
    Training strategy that wraps a model and a loss function together so
    that a single forward call yields the training loss.

    Example::

        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling

        # define the model
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 50,
            p_norm = 1,
            norm_flag = True
        )

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )
    """

    def __init__(
        self,
        model: typing.Optional[Model] = None,
        loss: typing.Optional[Loss] = None,
        regul_rate: float = 0.0,
        l3_regul_rate: float = 0.0):

        """Create a NegativeSampling object.

        :param model: KGE model
        :type model: :py:class:`unike.module.model.Model`
        :param loss: loss function
        :type loss: :py:class:`unike.module.loss.Loss`
        :param regul_rate: weight decay coefficient
        :type regul_rate: float
        :param l3_regul_rate: L3 regularization coefficient
        :type l3_regul_rate: float
        """

        super(NegativeSampling, self).__init__()

        #: KGE model, i.e. :py:class:`unike.module.model.Model`
        self.model: typing.Optional[Model] = model
        #: Loss function, i.e. :py:class:`unike.module.loss.Loss`
        self.loss: typing.Optional[Loss] = loss
        #: Weight decay coefficient
        self.regul_rate: float = regul_rate
        #: L3 regularization coefficient
        self.l3_regul_rate: float = l3_regul_rate

    def forward(self, data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """Compute the final loss value for one batch.

        Defines the computation performed at every call.
        :py:class:`torch.nn.Module` subclasses must override
        :py:meth:`torch.nn.Module.forward`.

        The batch is read through the keys ``"positive_sample"``,
        ``"negative_sample"`` and ``"mode"``. When ``mode`` is ``"bern"`` the
        negative samples are scored on their own and the scores are reshaped
        to one row per positive score; for any other mode the model is called
        with the positive samples, the negative samples and the mode.

        :param data: batch data
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: loss value
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]

        pos_score = self.model(pos_sample)
        if mode == "bern":
            neg_score = self.model(neg_sample)
            # Group the negative scores so that row i belongs to positive i.
            neg_score = neg_score.view(pos_score.shape[0], -1)
        else:
            neg_score = self.model(pos_sample, neg_sample, mode)

        loss_res = self.loss(pos_score, neg_score)

        # Accumulate out-of-place: an in-place `+=` would modify the tensor
        # handed back by the loss module, which autograd may reject if that
        # tensor was saved for the backward pass, and which cannot broadcast
        # to a larger shape.
        if self.regul_rate != 0:
            loss_res = loss_res + self.regul_rate * self.model.regularization(data)
        if self.l3_regul_rate != 0:
            loss_res = loss_res + self.l3_regul_rate * self.model.l3_regularization()
        return loss_res
def get_negative_sampling_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration of
    :py:class:`NegativeSampling`.

    The default configuration is::

        parameters_dict = {
            'strategy': {
                'value': 'NegativeSampling'
            },
            'regul_rate': {
                'value': 0.0
            },
            'l3_regul_rate': {
                'value': 0.0
            }
        }

    :returns: default hyper-parameter optimization configuration of
        :py:class:`NegativeSampling`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # Every entry is a fixed value, so each one is wrapped as {'value': ...}.
    fixed_settings = (
        ('strategy', 'NegativeSampling'),
        ('regul_rate', 0.0),
        ('l3_regul_rate', 0.0),
    )
    return {name: {'value': setting} for name, setting in fixed_settings}