Shortcuts

unike.data.RGCNSampler 源代码

# coding:utf-8
#
# unike/data/RGCNSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 21, 2024
#
# R-GCN 的数据采样器.

"""
RGCNSampler - R-GCN 的数据采样器。
"""

import dgl
import torch
import typing
import warnings
import numpy as np
from .RevSampler import RevSampler

# NOTE(review): this silences *every* warning for the whole process, not just
# for this module (presumably to hide numpy/torch warnings raised while
# sampling). Consider narrowing it to a targeted filter or a
# `warnings.catch_warnings()` block around the code that actually warns.
warnings.filterwarnings("ignore")

class RGCNSampler(RevSampler):

    """Training-data sampler for ``R-GCN`` :cite:`R-GCN`.

    Each call to :meth:`sampling` draws ``batch_size`` positive triples,
    re-indexes their entities into a compact local ID space, builds
    ``neg_ent`` head-corrupted and ``neg_ent`` tail-corrupted negatives per
    positive triple, and builds a DGL subgraph from a random fraction
    (``graph_split_size``) of the positive edges.

    Example::

        from unike.data import RGCNSampler, CompGCNSampler
        from torch.utils.data import DataLoader

        #: training data sampler
        train_sampler: typing.Union[RGCNSampler, CompGCNSampler] = RGCNSampler(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        #: training triples
        data_train: list[tuple[int, int, int]] = train_sampler.get_train()

        train_dataloader = DataLoader(
            data_train,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=train_sampler.sampling,
        )
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1,
        graph_split_size: float = 0.5):

        """Create an RGCNSampler.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity file name (e.g. entity2id.txt)
        :type ent_file: str
        :param rel_file: relation file name (e.g. relation2id.txt)
        :type rel_file: str
        :param train_file: training-triple file name (e.g. train2id.txt)
        :type train_file: str
        :param batch_size: number of positive triples sampled per call to
            :meth:`sampling`; must be set before sampling
        :type batch_size: int | None
        :param neg_ent: number of negatives built per positive triple for each
            corruption mode (head and tail), i.e. ``2 * neg_ent`` in total
        :type neg_ent: int
        :param graph_split_size: fraction of the sampled positive triples used
            as edges of the message-passing subgraph, in ``(0, 1]``
        :type graph_split_size: float
        :raises ValueError: if ``graph_split_size`` is outside ``(0, 1]``
        """

        if not 0.0 < graph_split_size <= 1.0:
            raise ValueError(
                f"graph_split_size must be in (0, 1], got {graph_split_size!r}")

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file
        )

        #: batch size (number of positive triples sampled per call)
        self.batch_size: int | None = batch_size
        #: negatives per positive triple, for each of head / tail corruption
        self.neg_ent: int = neg_ent
        #: fraction of positive triples used as subgraph edges
        self.graph_split_size: float = graph_split_size

        # State of the most recent call to ``sampling`` (None until then).
        self.entity = None
        self.triples = None
        self.label = None
        self.graph = None
        self.relation = None
        self.norm = None

    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]:

        """Collate function producing one ``R-GCN`` :cite:`R-GCN` training batch.

        :param pos_triples: positive ``(head, relation, tail)`` triples of the
            knowledge graph
        :type pos_triples: list[tuple[int, int, int]]
        :returns: dict with keys ``'entity'`` (original entity IDs, indexed by
            local ID), ``'triples'`` (positives followed by head- then
            tail-corrupted negatives, using local entity IDs), ``'label'``
            (1 for positives, 0 for negatives), ``'graph'``, ``'relation'``
            (relation of each subgraph edge) and ``'norm'`` (per-edge
            normalisation)
        :rtype: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        """

        batch_data = {}

        pos_triples, self.entity = self.sampling_positive(np.asarray(pos_triples))
        head_triples = self.sampling_negative('head', pos_triples)
        tail_triples = self.sampling_negative('tail', pos_triples)
        self.triples = np.concatenate((pos_triples, head_triples, tail_triples))
        batch_data['entity'] = self.entity
        batch_data['triples'] = torch.from_numpy(self.triples)

        # Positives are placed first in ``self.triples``; only they get label 1.
        num_pos = len(pos_triples)
        self.label = torch.zeros((len(self.triples), 1))
        self.label[:num_pos] = 1
        batch_data['label'] = self.label

        # The message-passing graph uses only a random subset of the positives.
        split_size = int(num_pos * self.graph_split_size)
        graph_split_ids = np.random.choice(num_pos, size=split_size, replace=False)
        graph_triples = pos_triples[graph_split_ids]
        # torch.tensor copies, so each column tensor is contiguous.
        head = torch.tensor(graph_triples[:, 0], dtype=torch.long)
        rela = torch.tensor(graph_triples[:, 1], dtype=torch.long)
        tail = torch.tensor(graph_triples[:, 2], dtype=torch.long)

        self.graph, self.relation, self.norm = self.build_graph(
            len(self.entity), (head, rela, tail), -1)
        batch_data['graph'] = self.graph
        batch_data['relation'] = self.relation
        batch_data['norm'] = self.norm

        return batch_data

    def sampling_positive(
        self,
        positive_triples: list[tuple[int, int, int]]) -> tuple[np.ndarray, torch.Tensor]:

        """Sample ``batch_size`` triples (without replacement) and re-index their entities.

        :param positive_triples: positive ``(head, relation, tail)`` triples
        :type positive_triples: list[tuple[int, int, int]]
        :returns: the sampled triples as an ``(batch_size, 3)`` array whose
            head/tail columns hold local entity IDs, and an ``(num_local, 1)``
            long tensor mapping each local ID to its original entity ID
        :rtype: tuple[numpy.ndarray, torch.Tensor]
        :raises ValueError: if ``batch_size`` was not set
        """

        if self.batch_size is None:
            raise ValueError("RGCNSampler.batch_size must be set before sampling.")

        positive_triples = np.asarray(positive_triples)
        edge_ids = np.random.choice(
            len(positive_triples), size=self.batch_size, replace=False)
        head, rela, tail = positive_triples[edge_ids].transpose()
        # ``entity`` holds the sorted unique original IDs over heads and tails;
        # ``index`` gives, per position, its index into ``entity`` (= local ID).
        entity, index = np.unique((head, tail), return_inverse=True)
        # reshape keeps this correct whether numpy returns ``index`` flat or 2-D.
        head, tail = np.reshape(index, (2, -1))
        return (np.stack((head, rela, tail)).transpose(),
                torch.from_numpy(entity).view(-1, 1).long())

    def sampling_negative(
        self,
        mode: typing.Literal['head', 'tail'],
        pos_triples: list[tuple[int, int, int]]) -> np.ndarray:

        """Build negatives by replacing the head or tail with a random local entity.

        Replacement IDs are drawn uniformly from ``range(len(self.entity))``, so
        this must be called after ``self.entity`` has been set by
        :meth:`sampling`. No filtering against known positives is done.

        :param mode: ``'head'`` or ``'tail'``, the column to corrupt
        :type mode: str
        :param pos_triples: positive triples (local entity IDs)
        :type pos_triples: list[tuple[int, int, int]]
        :returns: ``neg_ent * len(pos_triples)`` corrupted triples
        :rtype: numpy.ndarray
        :raises ValueError: if ``mode`` is not ``'head'`` or ``'tail'``
        """

        if mode == 'head':
            column = 0
        elif mode == 'tail':
            column = 2
        else:
            raise ValueError(f"mode must be 'head' or 'tail', got {mode!r}")

        neg_random = np.random.choice(
            len(self.entity),
            size=self.neg_ent * len(pos_triples)
        )
        neg_samples = np.tile(pos_triples, (self.neg_ent, 1))
        neg_samples[:, column] = neg_random
        return neg_samples

    def build_graph(
        self,
        num_ent: int,
        triples: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        power: int = -1) -> tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor]:

        """Build the message-passing subgraph.

        :param num_ent: number of nodes in the subgraph
        :type num_ent: int
        :param triples: ``(head, relation, tail)`` ID tensors of the subgraph edges
        :type triples: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        :param power: exponent applied to the destination in-degree
        :type power: int
        :returns: the subgraph, the relation ID of each edge, and each edge's
            normalisation coefficient
        :rtype: tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor]
        """

        head, rela, tail = triples[0], triples[1], triples[2]
        graph = dgl.graph((head, tail), num_nodes=num_ent)
        node_norm = self.comp_deg_norm(graph, power)
        edge_norm = self.node_norm_to_edge_norm(graph, node_norm)
        # as_tensor avoids the copy-construct warning that torch.tensor(tensor) emits.
        rela = torch.as_tensor(rela, dtype=torch.long)
        return graph, rela, edge_norm

    def comp_deg_norm(
        self,
        graph: dgl.DGLGraph,
        power: int = -1) -> torch.Tensor:

        """Compute each node's normalisation coefficient ``in_degree ** power``.

        Nodes whose coefficient is infinite (in-degree 0 with a negative power)
        get 0 instead.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param power: exponent
        :type power: int
        :returns: per-node normalisation coefficients
        :rtype: torch.Tensor
        """

        in_deg = graph.in_degrees().float().numpy()
        # 0 ** negative -> inf; silence numpy's divide-by-zero warning and zero those out.
        with np.errstate(divide='ignore'):
            norm = np.power(in_deg, power)
        norm[np.isinf(norm)] = 0
        return torch.from_numpy(norm)

    def node_norm_to_edge_norm(
        self,
        graph: dgl.DGLGraph,
        node_norm: torch.Tensor) -> torch.Tensor:

        """Give every edge the normalisation coefficient of its destination node.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param node_norm: per-node normalisation coefficients
        :type node_norm: torch.Tensor
        :returns: per-edge coefficients, shape ``(num_edges, 1)``, in edge-ID order
        :rtype: torch.Tensor
        """

        # graph.edges() returns (src, dst) in edge-ID order, matching edata order.
        _, dst = graph.edges()
        return node_norm.view(-1, 1)[dst.long()]

Docs

Access comprehensive developer documentation for UniKE

View Docs