Shortcuts

unike.data.UniSampler 源代码

# coding:utf-8
#
# unike/data/UniSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
#
# 平移模型和语义匹配模型的训练集数据采样器.

"""
UniSampler - 平移模型和语义匹配模型的训练集数据采样器。
"""

import torch
import typing
import warnings
import numpy as np
from .TradSampler import TradSampler
from typing_extensions import override

# NOTE(review): this prepends an "ignore" filter for every warning category to
# the interpreter-global ``warnings.filters`` list as soon as this module is
# imported, so it also silences warnings raised by unrelated code in the
# caller's program. Kept as-is for backward compatibility; consider scoping it
# with ``warnings.catch_warnings()`` around the code that actually needs it.
warnings.simplefilter("ignore")

class UniSampler(TradSampler):

    """Uniform (mean-distribution) training-set sampler for translation and
    semantic-matching models.

    Every call to :meth:`sampling` alternates between corrupting the head
    entity (``"head-batch"``) and the tail entity (``"tail-batch"``) of the
    positive triples it receives.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create a UniSampler object.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param batch_size: has no effect in this sampler; it is only a placeholder
            forwarded to the base class.
        :type batch_size: int | None
        :param neg_ent: number of negative triples built for each positive
            triple by replacing an entity
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent,
        )

        # Flipped at the start of every sampling() call. Starting at 0 means the
        # first batch is a "tail-batch", the second a "head-batch", and so on.
        self.cross_sampling_flag = 0

    @override
    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[str, torch.Tensor]]:

        """Uniformly sample negatives for a batch of positive triples.

        Consecutive calls alternate between head corruption and tail
        corruption; ``self.neg_ent`` negatives are drawn per positive triple.

        :param pos_triples: correct triples from the knowledge graph, as
            ``(head, relation, tail)`` tuples
        :type pos_triples: list[tuple[int, int, int]]
        :returns: a dict with ``'mode'`` (``"head-batch"`` or ``"tail-batch"``),
            ``'positive_sample'`` (the positive triples as a LongTensor) and
            ``'negative_sample'`` (the replacement entities as a LongTensor,
            one row per positive triple)
        :rtype: dict[str, typing.Union[str, torch.Tensor]]
        """

        self.cross_sampling_flag = 1 - self.cross_sampling_flag

        batch_data = {}
        if self.cross_sampling_flag == 0:
            batch_data['mode'] = "head-batch"
            neg_ent_sample = [
                self.head_batch(t, r, self.neg_ent) for _, r, t in pos_triples
            ]
        else:
            batch_data['mode'] = "tail-batch"
            neg_ent_sample = [
                self.tail_batch(h, r, self.neg_ent) for h, r, _ in pos_triples
            ]

        batch_data["positive_sample"] = torch.LongTensor(np.array(pos_triples))
        batch_data['negative_sample'] = torch.LongTensor(np.array(neg_ent_sample))
        return batch_data

    def head_batch(
        self,
        t: int,
        r: int,
        neg_size: int | None = None) -> np.ndarray:

        """Build negative triples by replacing the head entity.

        :param t: tail entity
        :type t: int
        :param r: relation
        :type r: int
        :param neg_size: number of negative triples; ``None`` falls back to
            ``self.neg_ent``. A non-positive value yields an empty array.
        :type neg_size: int | None
        :returns: head entities of the negative triples
        :rtype: numpy.ndarray
        """

        return self._collect_negatives(self.corrupt_head, t, r, neg_size)

    def tail_batch(
        self,
        h: int,
        r: int,
        neg_size: int | None = None) -> np.ndarray:

        """Build negative triples by replacing the tail entity.

        :param h: head entity
        :type h: int
        :param r: relation
        :type r: int
        :param neg_size: number of negative triples; ``None`` falls back to
            ``self.neg_ent``. A non-positive value yields an empty array.
        :type neg_size: int | None
        :returns: tail entities of the negative triples
        :rtype: numpy.ndarray
        """

        return self._collect_negatives(self.corrupt_tail, h, r, neg_size)

    def _collect_negatives(
        self,
        corrupt: typing.Callable[..., np.ndarray],
        entity: int,
        r: int,
        neg_size: int | None) -> np.ndarray:

        """Repeatedly call ``corrupt(entity, r, num_max=...)`` until at least
        ``neg_size`` candidates are gathered, then truncate to ``neg_size``.

        :param corrupt: bound ``corrupt_head`` or ``corrupt_tail`` method
        :param entity: the entity kept fixed (tail for head corruption,
            head for tail corruption)
        :param r: relation
        :param neg_size: number of candidates wanted; ``None`` means ``self.neg_ent``
        :returns: exactly ``neg_size`` candidate entities (empty if ``neg_size <= 0``)
        """

        if neg_size is None:
            neg_size = self.neg_ent
        if neg_size <= 0:
            # np.concatenate([]) raises ValueError, so short-circuit here.
            return np.empty(0, dtype=np.int64)

        neg_list = []
        neg_cur_size = 0
        while neg_cur_size < neg_size:
            # Over-request (2x the remaining need) because corrupt_* may return
            # fewer than num_max candidates (presumably after filtering out
            # known positives — TODO confirm against TradSampler).
            # NOTE(review): this loops forever if corrupt_* keeps returning
            # empty arrays; confirm TradSampler guarantees progress.
            neg_tmp = corrupt(entity, r, num_max=(neg_size - neg_cur_size) * 2)
            neg_list.append(neg_tmp)
            neg_cur_size += len(neg_tmp)
        return np.concatenate(neg_list)[:neg_size]

Docs

Access comprehensive developer documentation for UniKE

View Docs