Shortcuts

unike.data.KGReader 源代码

# coding:utf-8
#
# unike/data/KGReader.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 17, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
#
# 从文件中读取知识图谱.

"""
KGReader - 从文件中读取知识图谱。
"""

import os
import numpy as np
import collections

class KGReader:

    """Read a knowledge graph from id-mapping and triple files.

    On construction the entity file, the relation file and the training
    triple file are read from :py:attr:`in_path`.  The ``(h, r) -> t`` and
    ``(r, t) -> h`` indexes are *not* built automatically; call
    :py:meth:`get_hr2t_rt2h_from_train` for that.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt"):

        """Create a KGReader object and load the three files.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity mapping file name (default ``entity2id.txt``)
        :type ent_file: str
        :param rel_file: relation mapping file name (default ``relation2id.txt``)
        :type rel_file: str
        :param train_file: training triple file name (default ``train2id.txt``)
        :type train_file: str
        """

        #: Dataset directory
        self.in_path: str = in_path
        #: Entity mapping file name
        self.ent_file: str = ent_file
        #: Relation mapping file name
        self.rel_file: str = rel_file
        #: Training triple file name
        self.train_file: str = train_file

        #: Number of entities (as declared on the first line of the entity file)
        self.ent_tol: int = 0
        #: Number of relations (as declared on the first line of the relation file)
        self.rel_tol: int = 0
        #: Number of training triples (as declared on the first line of the train file)
        self.train_tol: int = 0

        #: entity -> ID
        self.ent2id: dict = {}
        #: relation -> ID
        self.rel2id: dict = {}
        #: ID -> entity
        self.id2ent: dict = {}
        #: ID -> relation
        self.id2rel: dict = {}

        #: Training triples, stored as ``(h, r, t)``
        self.train_triples: list[tuple[int, int, int]] = []
        #: For every (h, r) pair in the training set, the set of matching t;
        #: values become ``np.ndarray`` after :py:meth:`get_hr2t_rt2h_from_train`
        self.hr2t_train: collections.defaultdict = collections.defaultdict(set)
        #: For every (r, t) pair in the training set, the set of matching h;
        #: values become ``np.ndarray`` after :py:meth:`get_hr2t_rt2h_from_train`
        self.rt2h_train: collections.defaultdict = collections.defaultdict(set)

        self.get_id()
        self.get_train_triples_id()

    def _read_mapping(self, file_name: str, name2id: dict, id2name: dict) -> int:

        """Read one ``name<TAB>id`` mapping file into the two given dicts.

        The first line of the file holds the declared number of entries; it is
        returned as-is and not checked against the number of lines read.
        Blank lines are skipped.

        :param file_name: file name inside :py:attr:`in_path`
        :type file_name: str
        :param name2id: dict filled with ``name -> id``
        :type name2id: dict
        :param id2name: dict filled with ``id -> name``
        :type id2name: dict
        :returns: the count declared on the first line
        :rtype: int
        """

        # Explicit UTF-8 so names decode the same way on every platform.
        with open(os.path.join(self.in_path, file_name), encoding="utf-8") as f:
            total = int(f.readline())
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines, e.g. an extra newline at end of file.
                    continue
                name, idx = line.split("\t")
                idx = int(idx)
                name2id[name] = idx
                id2name[idx] = name
        return total

    def get_id(self):

        """Read the :py:attr:`ent_file` and :py:attr:`rel_file` files."""

        self.ent_tol = self._read_mapping(self.ent_file, self.ent2id, self.id2ent)
        self.rel_tol = self._read_mapping(self.rel_file, self.rel2id, self.id2rel)

    def get_train_triples_id(self):

        """Read the :py:attr:`train_file` file.

        Each data line holds three whitespace-separated integers in the order
        ``h t r``; they are appended to :py:attr:`train_triples` as
        ``(h, r, t)``.  Blank lines are skipped.
        """

        with open(os.path.join(self.in_path, self.train_file), encoding="utf-8") as f:
            self.train_tol = int(f.readline())
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines, e.g. an extra newline at end of file.
                    continue
                # File column order is head, tail, relation.
                h, t, r = fields
                self.train_triples.append((int(h), int(r), int(t)))

    @staticmethod
    def _as_set(values) -> set:

        """Return *values* as a set of plain Python items (arrays are unpacked)."""

        if isinstance(values, np.ndarray):
            return set(values.tolist())
        return set(values)

    def get_hr2t_rt2h_from_train(self):

        """Build :py:attr:`hr2t_train` and :py:attr:`rt2h_train`.

        After the call every value in both indexes is an ``np.ndarray``.
        Entries already present in the indexes are kept and merged with the
        training triples, so calling this method more than once is safe.
        """

        # Start from whatever is already indexed; values may be sets (fresh
        # object) or arrays (method was called before).
        hr2t = {key: self._as_set(vals) for key, vals in self.hr2t_train.items()}
        rt2h = {key: self._as_set(vals) for key, vals in self.rt2h_train.items()}

        for h, r, t in self.train_triples:
            hr2t.setdefault((h, r), set()).add(t)
            rt2h.setdefault((r, t), set()).add(h)

        # Write back in place so existing references to the dicts stay valid.
        for key, tails in hr2t.items():
            self.hr2t_train[key] = np.array(list(tails))
        for key, heads in rt2h.items():
            self.rt2h_train[key] = np.array(list(heads))

    def get_hr_train(self):

        """Regroup the training data by ``(h, r)`` pair.

        Used for ``CompGCN`` :cite:`CompGCN` training, because the composition
        operation of ``CompGCN`` :cite:`CompGCN` only needs the head entity and
        the relation.
        For more details see :ref:`CompGCN <compgcn>`.

        The current :py:attr:`train_triples` is saved in ``self.t_triples`` and
        :py:attr:`train_triples` is replaced by a list of
        ``((h, r), [t, ...])`` pairs taken from :py:attr:`hr2t_train`.
        """

        # Without the index the result would silently be an empty training
        # set, so build it on demand.
        if not self.hr2t_train and self.train_triples:
            self.get_hr2t_rt2h_from_train()

        self.t_triples = self.train_triples
        self.train_triples = [(hr, list(t)) for (hr, t) in self.hr2t_train.items()]

    def get_train(self) -> list[tuple[int, int, int]]:

        """Return the training data.

        These are ``(h, r, t)`` triples, unless :py:meth:`get_hr_train` has
        been called, in which case they are ``((h, r), [t, ...])`` pairs.

        :returns: :py:attr:`train_triples`
        :rtype: list[tuple[int, int, int]]
        """

        return self.train_triples

Docs

Access comprehensive developer documentation for UniKE

View Docs