Shortcuts

unike.module.loss.MarginLoss 源代码

# coding:utf-8
#
# unike/module/loss/MarginLoss.py
#
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
#
# 该脚本定义了 margin-based ranking criterion 损失函数.

"""
MarginLoss - 损失函数类,TransE 原论文中应用这种损失函数完成模型学习。
"""

import torch
import numpy as np
from typing import Any
import torch.nn as nn
import torch.nn.functional as F
from .Loss import Loss

[文档]class MarginLoss(Loss): """ ``TransE`` :cite:`TransE` 原论文中应用这种损失函数完成模型训练。 .. Note:: :py:meth:`forward` 中的正样本评分函数的得分应小于负样本评分函数的得分。 例子:: from unike.module.model import TransE from unike.module.loss import MarginLoss from unike.module.strategy import NegativeSampling # define the model transe = TransE( ent_tol = dataloader.get_ent_tol(), rel_tol = dataloader.get_rel_tol(), dim = 50, p_norm = 1, norm_flag = True ) # define the loss function model = NegativeSampling( model = transe, loss = MarginLoss(margin = 1.0), regul_rate = 0.01 ) """
[文档] def __init__( self, adv_temperature: float | None = None, margin: float = 6.0): """创建 MarginLoss 对象。 :param adv_temperature: RotatE 提出的自我对抗负采样中的温度。 :type adv_temperature: float :param margin: gamma。 :type margin: float """ super(MarginLoss, self).__init__() #: gamma self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin])) self.margin.requires_grad = False if adv_temperature != None: #: RotatE 提出的自我对抗负采样中的温度。 self.adv_temperature: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([adv_temperature])) self.adv_temperature.requires_grad = False #: 是否启用 RotatE 提出的自我对抗负采样。 self.adv_flag: bool = True else: self.adv_flag: bool = False
[文档] def get_weights( self, n_score: torch.Tensor) -> torch.Tensor: """计算 RotatE 提出的自我对抗负采样中的负样本的分布概率。 :param n_score: 负样本评分函数的得分。 :type n_score: torch.Tensor :returns: 自我对抗负采样中的负样本的分布概率 :rtype: torch.Tensor """ return F.softmax(-n_score * self.adv_temperature, dim = -1).detach()
[文档] def forward( self, p_score: torch.Tensor, n_score: torch.Tensor) -> torch.Tensor: """计算 margin-based ranking criterion 损失函数。定义每次调用时执行的计算。 :py:class:`torch.nn.Module` 子类必须重写 :py:meth:`torch.nn.Module.forward`。 :param p_score: 正样本评分函数的得分。 :type p_score: torch.Tensor :param n_score: 负样本评分函数的得分。 :type n_score: torch.Tensor :returns: 损失值 :rtype: torch.Tensor """ if self.adv_flag: return (self.get_weights(n_score) * torch.max(p_score - n_score, -self.margin)).sum(dim = -1).mean() + self.margin else: return (torch.max(p_score - n_score, -self.margin)).mean() + self.margin
[文档]def get_margin_loss_hpo_config() -> dict[str, dict[str, Any]]: """返回 :py:class:`MarginLoss` 的默认超参数优化配置。 默认配置为:: parameters_dict = { 'loss': { 'value': 'MarginLoss' }, 'adv_temperature': { 'values': [1.0, 3.0, 6.0] }, 'margin': { 'values': [1.0, 3.0, 6.0] } } :returns: :py:class:`MarginLoss` 的默认超参数优化配置 :rtype: dict[str, dict[str, typing.Any]] """ parameters_dict = { 'loss': { 'value': 'MarginLoss' }, 'adv_temperature': { 'values': [1.0, 3.0, 6.0] }, 'margin': { 'values': [1.0, 3.0, 6.0] } } return parameters_dict

Docs

Access comprehensive developer documentation for UniKE

View Docs