Shortcuts

unike.data.BernSampler 源代码

# coding:utf-8
#
# unike/data/BernSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 30, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
#
# 该脚本定义了 BernSampler 类.

"""
BernSampler - 平移模型和语义匹配模型的训练集数据采样器。
"""

import torch
import typing
import random
import collections
import numpy as np
from .TradSampler import TradSampler
from typing_extensions import override

class BernSampler(TradSampler):

    """
    Bernoulli ("bern") training-set sampler for translational and semantic
    matching models. See :ref:`TransH <transh>` for more details.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create a BernSampler object.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param batch_size: has no effect in this sampler; it is only a placeholder.
        :type batch_size: int | None
        :param neg_ent: number of negative triples built for every positive
            triple by replacing an entity
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        # Per-relation statistics that drive the Bernoulli head/tail choice.
        self.tph, self.hpt = self.get_tph_hpt()

    def get_tph_hpt(self) -> tuple[collections.defaultdict[int, float], collections.defaultdict[int, float]]:

        """Compute tph and hpt for every relation in ``self.train_triples``.

        For a relation ``r``, ``tph[r]`` is the number of training triples
        with that relation divided by its number of distinct heads, and
        ``hpt[r]`` is the same count divided by its number of distinct tails.

        :returns: tph and hpt
        :rtype: tuple[collections.defaultdict[int, float], collections.defaultdict[int, float]]
        """

        heads_of_rel = collections.defaultdict(set)
        tails_of_rel = collections.defaultdict(set)
        triples_of_rel = collections.defaultdict(float)

        for h, r, t in self.train_triples:
            triples_of_rel[r] += 1.0
            heads_of_rel[r].add(h)
            tails_of_rel[r].add(t)

        tph = collections.defaultdict(float)
        hpt = collections.defaultdict(float)
        for r, count in triples_of_rel.items():
            tph[r] = count / len(heads_of_rel[r])
            hpt[r] = count / len(tails_of_rel[r])

        return tph, hpt

    @override
    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[str, torch.Tensor]]:

        """Bernoulli ("bern") sampling function for the training set of
        translational and semantic matching models.

        :param pos_triples: correct (positive) triples of the knowledge graph
        :type pos_triples: list[tuple[int, int, int]]
        :returns: training data with the keys ``mode``, ``positive_sample``
            and ``negative_sample``
        :rtype: dict[str, typing.Union[str, torch.Tensor]]
        """

        neg_ent_sample = []
        for h, r, t in pos_triples:
            neg_ent_sample.extend(self.__normal_batch(h, r, t, self.neg_ent))

        return {
            'mode': 'bern',
            "positive_sample": torch.LongTensor(np.array(pos_triples)),
            "negative_sample": torch.LongTensor(np.array(neg_ent_sample)),
        }

    def __normal_batch(
        self,
        h: int,
        r: int,
        t: int,
        neg_size: int) -> list[tuple[int, int, int]]:

        """Bernoulli negative sampling for a single positive triple.

        Each of the ``neg_size`` negatives corrupts the tail with probability
        ``hpt[r] / (hpt[r] + tph[r])`` and the head otherwise.

        :param h: head entity
        :type h: int
        :param r: relation
        :type r: int
        :param t: tail entity
        :type t: int
        :param neg_size: number of negative triples
        :type neg_size: int
        :returns: list of negative triples (corrupted-head triples first,
            then corrupted-tail triples)
        :rtype: list[tuple[int, int, int]]
        """

        # ``.get`` keeps the lookup from inserting zero entries into the
        # defaultdicts for a relation that never occurred in training.
        hpt = self.hpt.get(r, 0.0)
        tph = self.tph.get(r, 0.0)
        total = hpt + tph
        # No statistics for this relation: corrupt head or tail with equal
        # probability instead of failing on 0.0 / 0.0.
        prob = hpt / total if total > 0 else 0.5

        # One random draw per requested negative decides which side to corrupt.
        neg_size_t = sum(random.random() < prob for _ in range(neg_size))
        neg_size_h = neg_size - neg_size_t

        res = []
        for hh in self.__collect_corrupted(self.corrupt_head, t, r, neg_size_h):
            res.append((hh, r, t))
        for tt in self.__collect_corrupted(self.corrupt_tail, h, r, neg_size_t):
            res.append((h, r, tt))

        return res

    def __collect_corrupted(
        self,
        corrupt_fn: typing.Callable[..., typing.Any],
        entity: int,
        r: int,
        wanted: int):

        """Call ``corrupt_fn`` until at least ``wanted`` candidates are gathered.

        :param corrupt_fn: ``self.corrupt_head`` or ``self.corrupt_tail``;
            presumably returns a sized array-like of entity ids (defined in
            the parent class, confirm there)
        :type corrupt_fn: typing.Callable[..., typing.Any]
        :param entity: the entity of the positive triple that is kept
        :type entity: int
        :param r: relation
        :type r: int
        :param wanted: number of replacement entities required
        :type wanted: int
        :returns: the first ``wanted`` gathered entities (empty list when
            ``wanted`` is 0)
        """

        chunks = []
        collected = 0
        while collected < wanted:
            # Request twice the remaining amount on every round.
            chunk = corrupt_fn(entity, r, num_max=(wanted - collected) * 2)
            chunks.append(chunk)
            collected += len(chunk)

        if not chunks:
            return []
        return np.concatenate(chunks)[:wanted]

Docs

Access comprehensive developer documentation for UniKE

View Docs