Shortcuts

unike.data.TestSampler 源代码

# coding:utf-8
#
# unike/data/TestSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
#
# 测试数据采样器基类.

"""
TestSampler - 测试数据采样器基类。
"""

import os
import torch
import typing
from .TradSampler import TradSampler
from .RGCNSampler import RGCNSampler
from .CompGCNSampler import CompGCNSampler
from collections import defaultdict as ddict
from ..utils import construct_type_constrain

[文档]class TestSampler(object): """测试数据采样器基类。 """
[文档] def __init__( self, sampler: typing.Union[TradSampler, RGCNSampler, CompGCNSampler], valid_file: str = "valid2id.txt", test_file: str = "test2id.txt", type_constrain: bool = True): """创建 TestSampler 对象。 :param sampler: 训练数据采样器。 :type sampler: typing.Union[TradSampler, RGCNSampler, CompGCNSampler] :param valid_file: valid2id.txt :type valid_file: str :param test_file: test2id.txt :type test_file: str :param type_constrain: 是否报告 type_constrain.txt 限制的测试结果 :type type_constrain: bool """ #: 训练数据采样器 self.sampler: typing.Union[TradSampler, RGCNSampler, CompGCNSampler] = sampler #: 实体的个数 self.ent_tol: int = sampler.ent_tol #: valid2id.txt self.valid_file: str = valid_file #: test2id.txt self.test_file: str = test_file #: 验证集三元组的个数 self.valid_tol: int = 0 #: 测试集三元组的个数 self.test_tol: int = 0 #: 验证集三元组 self.valid_triples: list[tuple[int, int, int]] = [] #: 测试集三元组 self.test_triples: list[tuple[int, int, int]] = [] #: 知识图谱所有三元组 self.all_true_triples: set[tuple[int, int, int]] = set() self.get_valid_test_triples_id() #: 知识图谱中所有 h-r 对对应的 t 集合 self.hr2t_all: ddict[set] = ddict(set) #: 知识图谱中所有 r-t 对对应的 h 集合 self.rt2h_all: ddict[set] = ddict(set) #: 是否报告 type_constrain.txt 限制的测试结果 self.type_constrain: bool = type_constrain if self.type_constrain: construct_type_constrain( in_path = self.sampler.in_path, train_file = self.sampler.train_file, valid_file = self.valid_file, test_file = self.test_file ) #: 知识图谱中所有 r 存在头实体种类 self.rel_heads: ddict[set] = ddict(set) #: 知识图谱中所有 r 存在尾实体种类 self.rel_tails: ddict[set] = ddict(set) self.get_type_constrain_id()
[文档] def get_valid_test_triples_id(self): """读取 :py:attr:`valid_file` 文件和 :py:attr:`test_file` 文件。""" with open(os.path.join(self.sampler.in_path, self.valid_file)) as f: self.valid_tol = (int)(f.readline()) for line in f: h, t, r = line.strip().split() self.valid_triples.append((int(h), int(r), int(t))) with open(os.path.join(self.sampler.in_path, self.test_file)) as f: self.test_tol = (int)(f.readline()) for line in f: h, t, r = line.strip().split() self.test_triples.append((int(h), int(r), int(t))) self.all_true_triples = set( self.sampler.train_triples + self.valid_triples + self.test_triples )
[文档] def get_type_constrain_id(self): """读取 type_constrain.txt 文件。""" with open(os.path.join(self.sampler.in_path, "type_constrain.txt")) as f: rel_tol = (int)(f.readline()) first_line = True for line in f: rel_types = line.strip().split("\t") for entity in rel_types[2:]: if first_line: self.rel_heads[int(rel_types[0])].add(int(entity)) else: self.rel_tails[int(rel_types[0])].add(int(entity)) first_line = not first_line for rel in self.rel_heads: self.rel_heads[rel] = torch.tensor(list(self.rel_heads[rel])) for rel in self.rel_tails: self.rel_tails[rel] = torch.tensor(list(self.rel_tails[rel]))
[文档] def get_hr2t_rt2h_from_all(self): """获得 :py:attr:`hr2t_all` 和 :py:attr:`rt2h_all` 。""" for h, r, t in self.all_true_triples: self.hr2t_all[(h, r)].add(t) self.rt2h_all[(r, t)].add(h) for h, r in self.hr2t_all: self.hr2t_all[(h, r)] = torch.tensor(list(self.hr2t_all[(h, r)])) for r, t in self.rt2h_all: self.rt2h_all[(r, t)] = torch.tensor(list(self.rt2h_all[(r, t)]))
[文档] def sampling( self, data: list[tuple[int, int, int]]) -> dict[str, torch.Tensor]: """采样函数。该方法未实现,子类必须重写该方法,否则抛出 :py:class:`NotImplementedError` 错误。 :param data: 测试的正确三元组 :type data: list[tuple[int, int, int]] :returns: 测试数据 :rtype: dict[str, torch.Tensor] """ raise NotImplementedError
[文档] def get_valid(self) -> list[tuple[int, int, int]]: """ 返回验证集三元组。 :returns: :py:attr:`valid_triples` :rtype: list[tuple[int, int, int]] """ return self.valid_triples
[文档] def get_test(self) -> list[tuple[int, int, int]]: """ 返回测试集三元组。 :returns: :py:attr:`test_triples` :rtype: list[tuple[int, int, int]] """ return self.test_triples
[文档] def get_all_true_triples(self) -> set[tuple[int, int, int]]: """ 返回知识图谱所有三元组。 :returns: :py:attr:`all_true_triples` :rtype: set[tuple[int, int, int]] """ return self.all_true_triples

Docs

Access comprehensive developer documentation for UniKE

View Docs