Shortcuts

unike.data.TradSampler 源代码

# coding:utf-8
#
# unike/data/TradSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 28, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
#
# 为 KGReader 增加构建负三元组的函数,用于平移模型和语义匹配模型.

"""
TradSampler - 为 KGReader 增加构建负三元组的函数,用于平移模型和语义匹配模型。
"""

import torch
import typing
import numpy as np
from .KGReader import KGReader

[文档]class TradSampler(KGReader): """ 平移模型和语义匹配模型的采样器的基类。 """
[文档] def __init__( self, in_path: str = "./", ent_file: str = "entity2id.txt", rel_file: str = "relation2id.txt", train_file: str = "train2id.txt", batch_size: int | None = None, neg_ent: int = 1): """创建 TradSampler 对象。 :param in_path: 数据集目录 :type in_path: str :param ent_file: entity2id.txt :type ent_file: str :param rel_file: relation2id.txt :type rel_file: str :param train_file: train2id.txt :type train_file: str :param batch_size: batch size 在该采样器中不起作用,只是占位符。 :type batch_size: int | None :param neg_ent: 对于每一个正三元组, 构建的负三元组的个数, 替换 entity :type neg_ent: int """ super().__init__( in_path=in_path, ent_file=ent_file, rel_file=rel_file, train_file=train_file ) #: batch size self.batch_size: int = batch_size #: 对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail) self.neg_ent: int = neg_ent self.get_hr2t_rt2h_from_train()
[文档] def sampling( self, pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[str, torch.Tensor]]: """平移模型和语义匹配模型的训练集数据采样函数。该方法未实现,子类必须重写该方法,否则抛出 :py:class:`NotImplementedError` 错误。 :param pos_triples: 知识图谱中的正确三元组 :type pos_triples: list[tuple[int, int, int]] :returns: 平移模型和语义匹配模型的训练数据 :rtype: dict[str, typing.Union[str, torch.Tensor]] """ raise NotImplementedError
[文档] def corrupt_head( self, t: int, r: int, num_max: int = 1) -> np.ndarray: """替换头实体构建负三元组。 :param t: 尾实体 :type t: int :param r: 关系 :type r: int :param num_max: 一次负采样的个数 :type num_max: int :returns: 负三元组的头实体列表 :rtype: numpy.ndarray """ tmp = torch.randint(low=0, high=self.ent_tol, size=(num_max,)).numpy() mask = np.in1d(tmp, self.rt2h_train[(r, t)], assume_unique=True, invert=True) neg = tmp[mask] return neg
[文档] def corrupt_tail( self, h: int, r: int, num_max: int = 1) -> np.ndarray: """替换尾实体构建负三元组。 :param h: 头实体 :type h: int :param r: 关系 :type r: int :param num_max: 一次负采样的个数 :type num_max: int :returns: 负三元组的尾实体列表 :rtype: numpy.ndarray """ tmp = torch.randint(low=0, high=self.ent_tol, size=(num_max,)).numpy() mask = np.in1d(tmp, self.hr2t_train[(h, r)], assume_unique=True, invert=True) neg = tmp[mask] return neg

Docs

Access comprehensive developer documentation for UniKE

View Docs