Shortcuts

unike.data.RevSampler 源代码

# coding:utf-8
#
# unike/data/RevSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 15, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
#
# 为 KGReader 增加相反关系,用于图神经网络模型.

"""
RevSampler - 为 KGReader 增加相反关系,用于图神经网络模型。
"""

import os
from .KGReader import KGReader

[文档]class RevSampler(KGReader): """在训练集中增加相反关系. 对于每一个三元组 (h, r, t),生成相反关系三元组 (t, r`, h): r` = r + rel_tol。 """
[文档] def __init__( self, in_path: str = "./", ent_file: str = "entity2id.txt", rel_file: str = "relation2id.txt", train_file: str = "train2id.txt"): """创建 RevSampler 对象。 :param in_path: 数据集目录 :type in_path: str :param ent_file: entity2id.txt :type ent_file: str :param rel_file: relation2id.txt :type rel_file: str :param train_file: train2id.txt :type train_file: str """ super().__init__( in_path=in_path, ent_file=ent_file, rel_file=rel_file, train_file=train_file ) self.add_reverse_relation() self.add_train_reverse_triples() self.get_hr2t_rt2h_from_train()
[文档] def add_reverse_relation(self): """增加相反关系:r` = r + rel_tol""" with open(os.path.join(self.in_path, self.rel_file)) as f: f.readline() for line in f: relation, rid = line.strip().split("\t") self.rel2id[relation + "_reverse"] = int(rid) + self.rel_tol self.id2rel[int(rid) + self.rel_tol] = relation + "_reverse" self.rel_tol = len(self.rel2id)
[文档] def add_train_reverse_triples(self): """对于每一个三元组 (h, r, t),生成相反关系三元组 (t, r`, h): r` = r + rel_tol。""" tol = int(self.rel_tol / 2) with open(os.path.join(self.in_path, self.train_file)) as f: f.readline() for line in f: h, t, r = line.strip().split() self.train_triples.append( (int(t), int(r) + tol, int(h)) )

Docs

Access comprehensive developer documentation for UniKE

View Docs