unike.data.RGCNTestSampler 源代码
# coding:utf-8
#
# unike/data/RGCNTestSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
#
# R-GCN 的测试数据采样器.
"""
RGCNTestSampler - R-GCN 的测试数据采样器。
"""
import os
import dgl
import torch
import typing
import numpy as np
from .TestSampler import TestSampler
from .RGCNSampler import RGCNSampler
from .CompGCNSampler import CompGCNSampler
from typing_extensions import override
[文档]class RGCNTestSampler(TestSampler):
"""``R-GCN`` :cite:`R-GCN` 的测试数据采样器。
例子::
from unike.data import RGCNTestSampler, CompGCNTestSampler
from torch.utils.data import DataLoader
#: 测试数据采样器
test_sampler: typing.Union[typing.Type[RGCNTestSampler], typing.Type[CompGCNTestSampler]] = test_sampler(
sampler=train_sampler,
valid_file=valid_file,
test_file=test_file,
)
#: 验证集三元组
data_val: list[tuple[int, int, int]] = test_sampler.get_valid()
#: 测试集三元组
data_test: list[tuple[int, int, int]] = test_sampler.get_test()
val_dataloader = DataLoader(
data_val,
shuffle=False,
batch_size=test_batch_size,
num_workers=num_workers,
pin_memory=True,
collate_fn=test_sampler.sampling,
)
test_dataloader = DataLoader(
data_test,
shuffle=False,
batch_size=test_batch_size,
num_workers=num_workers,
pin_memory=True,
collate_fn=test_sampler.sampling,
)
"""
[文档] def __init__(
self,
sampler: typing.Union[RGCNSampler, CompGCNSampler],
valid_file: str = "valid2id.txt",
test_file: str = "test2id.txt",
type_constrain: bool = True):
"""创建 RGCNTestSampler 对象。
:param sampler: 训练数据采样器。
:type sampler: typing.Union[RGCNSampler, CompGCNSampler]
:param valid_file: valid2id.txt
:type valid_file: str
:param test_file: test2id.txt
:type test_file: str
:param type_constrain: 是否报告 type_constrain.txt 限制的测试结果
:type type_constrain: bool
"""
super().__init__(
sampler=sampler,
valid_file=valid_file,
test_file=test_file,
type_constrain = type_constrain
)
#: 训练集三元组
self.triples: list[tuple[int, int, int]] = self.sampler.t_triples if isinstance(self.sampler, CompGCNSampler) else self.sampler.train_triples
#: 幂
self.power: float = -1
self.add_valid_test_reverse_triples()
self.get_hr2t_rt2h_from_all()
[文档] @override
def get_valid_test_triples_id(self):
"""读取 :py:attr:`valid_file` 文件和 :py:attr:`test_file` 文件。"""
with open(os.path.join(self.sampler.in_path, self.valid_file)) as f:
self.valid_tol = (int)(f.readline())
for line in f:
h, t, r = line.strip().split()
self.valid_triples.append((int(h), int(r), int(t)))
with open(os.path.join(self.sampler.in_path, self.test_file)) as f:
self.test_tol = (int)(f.readline())
for line in f:
h, t, r = line.strip().split()
self.test_triples.append((int(h), int(r), int(t)))
[文档] def add_valid_test_reverse_triples(self):
"""对于每一个三元组 (h, r, t),生成相反关系三元组 (t, r`, h): r` = r + rel_tol。"""
tol = int(self.sampler.rel_tol / 2)
with open(os.path.join(self.sampler.in_path, self.valid_file)) as f:
f.readline()
for line in f:
h, t, r = line.strip().split()
self.valid_triples.append(
(int(t), int(r) + tol, int(h))
)
with open(os.path.join(self.sampler.in_path, self.test_file)) as f:
f.readline()
for line in f:
h, t, r = line.strip().split()
self.test_triples.append(
(int(t), int(r) + tol, int(h))
)
self.all_true_triples = set(
self.triples + self.valid_triples + self.test_triples
)
[文档] @override
def get_type_constrain_id(self):
"""读取 type_constrain.txt 文件。"""
tol = int(self.sampler.rel_tol / 2)
with open(os.path.join(self.sampler.in_path, "type_constrain.txt")) as f:
rel_tol = (int)(f.readline())
first_line = True
for line in f:
rel_types = line.strip().split("\t")
for entity in rel_types[2:]:
if first_line:
self.rel_heads[int(rel_types[0])].add(int(entity))
self.rel_tails[int(rel_types[0]) + tol].add(int(entity))
else:
self.rel_tails[int(rel_types[0])].add(int(entity))
self.rel_heads[int(rel_types[0]) + tol].add(int(entity))
first_line = not first_line
for rel in self.rel_heads:
self.rel_heads[rel] = torch.tensor(list(self.rel_heads[rel]))
for rel in self.rel_tails:
self.rel_tails[rel] = torch.tensor(list(self.rel_tails[rel]))
[文档] @override
def sampling(
self,
data: list[tuple[int, int, int]]) -> dict[str, typing.Union[dgl.DGLGraph , torch.Tensor]]:
"""``R-GCN`` :cite:`R-GCN` 的测试数据采样函数。
:param data: 测试的正确三元组
:type data: list[tuple[int, int, int]]
:returns: ``R-GCN`` :cite:`R-GCN` 的测试数据
:rtype: dict[str, typing.Union[dgl.DGLGraph , torch.Tensor]]
"""
batch_data = {}
head_label = torch.zeros(len(data), self.ent_tol)
tail_label = torch.zeros(len(data), self.ent_tol)
for idx, triple in enumerate(data):
head, rel, tail = triple
head_label[idx][self.rt2h_all[(rel, tail)]] = 1.0
tail_label[idx][self.hr2t_all[(head, rel)]] = 1.0
if self.type_constrain:
head_label_type = torch.ones(len(data), self.ent_tol)
tail_laebl_type = torch.ones(len(data), self.ent_tol)
for idx, triple in enumerate(data):
head, rel, tail = triple
head_label_type[idx][self.rel_heads[rel]] = 0.0
tail_laebl_type[idx][self.rel_tails[rel]] = 0.0
head_label_type[idx][self.rt2h_all[(rel, tail)]] = 1.0
tail_laebl_type[idx][self.hr2t_all[(head, rel)]] = 1.0
batch_data["head_label_type"] = head_label_type
batch_data["tail_label_type"] = tail_laebl_type
batch_data["positive_sample"] = torch.tensor(data)
batch_data["head_label"] = head_label
batch_data["tail_label"] = tail_label
graph, rela, norm = self.sampler.build_graph(self.ent_tol, np.array(self.triples).transpose(), self.power)
batch_data["graph"] = graph
batch_data["rela"] = rela
batch_data["norm"] = norm
batch_data["entity"] = torch.arange(0, self.ent_tol, dtype=torch.long).view(-1,1)
return batch_data