# unike.data.CompGCNSampler 源代码
# coding:utf-8
#
# unike/data/CompGCNSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
#
# 该脚本定义了 CompGCNSampler 类.
"""
CompGCNSampler - CompGCN 的数据采样器。
"""
import dgl
import torch
import typing
import numpy as np
from .RGCNSampler import RGCNSampler
from typing_extensions import override
class CompGCNSampler(RGCNSampler):

    """Training-data sampler for ``CompGCN`` :cite:`CompGCN`.

    Example::

        from unike.data import RGCNSampler, CompGCNSampler
        from torch.utils.data import DataLoader

        #: training-data sampler
        train_sampler: typing.Union[typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = train_sampler(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        #: training-set triples
        data_train: list[tuple[int, int, int]] = train_sampler.get_train()

        train_dataloader = DataLoader(
            data_train,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=train_sampler.sampling,
        )
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create a CompGCNSampler object.

        All arguments are forwarded unchanged to the parent ``RGCNSampler``
        constructor; afterwards the graph used for training is built once and
        cached on the instance.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param batch_size: batch size
        :type batch_size: int | None
        :param neg_ent: has no effect for CompGCN.
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        # NOTE(review): presumably prepares the (head, relation) -> tails
        # training data consumed by ``sampling`` -- confirm in the parent class.
        super().get_hr_train()

        # The graph is built once here and reused for every batch.
        # NOTE(review): ``self.ent_tol`` and ``self.t_triples`` are expected to
        # be provided by the parent class; ``-0.5`` is presumably the exponent
        # used for degree normalisation -- confirm against ``build_graph``.
        self.graph, self.relation, self.norm = \
            self.build_graph(self.ent_tol, np.array(self.t_triples).transpose(), -0.5)

    @override
    def sampling(
        self,
        pos_hr_t: list[tuple[tuple[int, int], list[int]]]) -> dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]:

        """Sampling (collate) function for ``CompGCN`` :cite:`CompGCN`.

        Builds a multi-hot label matrix with one row per ``(head, relation)``
        pair in the batch and a ``1`` in every column listed for that pair.

        :param pos_hr_t: batch items, each a ``((head, relation), tails)`` pair
        :type pos_hr_t: list[tuple[tuple[int, int], list[int]]]
        :returns: training data for ``CompGCN`` :cite:`CompGCN` with the keys
            ``sample``, ``label``, ``graph``, ``relation`` and ``norm``
        :rtype: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        """

        # Size the label matrix from the batch actually received rather than
        # from ``self.batch_size``: the two differ for a short final batch
        # (``drop_last=False``) and ``self.batch_size`` may be ``None``.
        batch_len = len(pos_hr_t)
        self.label = torch.zeros(batch_len, self.ent_tol)
        self.triples = torch.LongTensor([hr for hr, _ in pos_hr_t])

        # Mark every listed tail entity of each (head, relation) pair.
        for row, tails in enumerate(t for _, t in pos_hr_t):
            self.label[row][tails] = 1

        return {
            'sample': self.triples,
            'label': self.label,
            'graph': self.graph,
            'relation': self.relation,
            'norm': self.norm,
        }

    @override
    def node_norm_to_edge_norm(
        self,
        graph: dgl.DGLGraph,
        node_norm: torch.Tensor) -> torch.Tensor:

        """Compute a per-edge normalisation coefficient from node coefficients.

        Each edge receives the product of the coefficients of its source and
        destination nodes. As a side effect ``node_norm`` is stored in
        ``graph.ndata['norm']``; the temporary ``'norm'`` edge feature is
        removed again before returning.

        :param graph: graph whose edges are to be normalised
        :type graph: dgl.DGLGraph
        :param node_norm: normalisation coefficient of every node
        :type node_norm: torch.Tensor
        :returns: normalisation coefficient of every edge
        :rtype: torch.Tensor
        """

        graph.ndata['norm'] = node_norm
        graph.apply_edges(lambda edges: {'norm': edges.dst['norm'] * edges.src['norm']})
        # ``squeeze`` drops all size-1 dimensions of the edge feature.
        return graph.edata.pop('norm').squeeze()