Shortcuts

unike.module.model.Model 源代码

# coding:utf-8
#
# unike/module/model/Model.py
# 
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2023
# 
# 该头文件定义了 Model.

"""Model 类 - 所有 KGE 模型的基类"""

import torch
from ..BaseModule import BaseModule

[文档]class Model(BaseModule): """ 继承自 :py:class:`unike.module.BaseModule`,仅仅增加了两个属性::py:attr:`ent_tol` 和 :py:attr:`rel_tol`。 """
[文档] def __init__( self, ent_tol: int, rel_tol: int): """创建 Model 对象。 :param ent_tol: 实体的个数 :type ent_tol: int :param rel_tol: 关系的个数 :type rel_tol: int """ super(Model, self).__init__() #: 实体的种类 self.ent_tol: int = ent_tol #: 关系的种类 self.rel_tol: int = rel_tol
[文档] def forward(self) -> torch.Tensor: """ 定义每次调用时执行的计算。该方法未实现,子类必须重写该方法,否则抛出 :py:class:`NotImplementedError` 错误。 :py:class:`torch.nn.Module` 子类必须重写 :py:meth:`torch.nn.Module.forward`。 :returns: 三元组的得分 :rtype: torch.Tensor """ raise NotImplementedError
[文档] def predict(self) -> torch.Tensor: """ KGE 模型的推理方法。该方法未实现,子类必须重写该方法,否则抛出 :py:class:`NotImplementedError` 错误。 :returns: 三元组的得分 :rtype: torch.Tensor """ raise NotImplementedError
[文档] def tri2emb( self, triples: torch.Tensor, negs: torch.Tensor = None, mode: str = 'single') -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ 返回三元组对应的嵌入向量。 :param triples: 正确的三元组 :type triples: torch.Tensor :param negs: 负三元组类别 :type negs: torch.Tensor :param mode: 模式 :type triples: str :returns: 头实体、关系和尾实体的嵌入向量 :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor] """ if mode == "single": head_emb = self.ent_embeddings(triples[:, 0]).unsqueeze(1) relation_emb = self.rel_embeddings(triples[:, 1]).unsqueeze(1) tail_emb = self.ent_embeddings(triples[:, 2]).unsqueeze(1) elif mode == "head-batch" or mode == "head_predict": if negs is None: head_emb = self.ent_embeddings.weight.data.unsqueeze(0) else: head_emb = self.ent_embeddings(negs) relation_emb = self.rel_embeddings(triples[:, 1]).unsqueeze(1) tail_emb = self.ent_embeddings(triples[:, 2]).unsqueeze(1) elif mode == "tail-batch" or mode == "tail_predict": head_emb = self.ent_embeddings(triples[:, 0]).unsqueeze(1) relation_emb = self.rel_embeddings(triples[:, 1]).unsqueeze(1) if negs is None: tail_emb = self.ent_embeddings.weight.data.unsqueeze(0) else: tail_emb = self.ent_embeddings(negs) return head_emb, relation_emb, tail_emb

Docs

Access comprehensive developer documentation for UniKE

View Docs