CompGCNSampler¶
- class unike.data.CompGCNSampler(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]¶
CompGCN[VSNT20] 的训练数据采样器。例子:
from unike.data import RGCNSampler, CompGCNSampler from torch.utils.data import DataLoader #: 训练数据采样器 train_sampler: typing.Union[typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = train_sampler( in_path=in_path, ent_file=ent_file, rel_file=rel_file, train_file=train_file, batch_size=batch_size, neg_ent=neg_ent ) #: 训练集三元组 data_train: list[tuple[int, int, int]] = train_sampler.get_train() train_dataloader = DataLoader( data_train, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, collate_fn=train_sampler.sampling, )
- __init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]¶
创建 CompGCNSampler 对象。
- 参数:
in_path (str) – 数据集目录
ent_file (str) – entity2id.txt
rel_file (str) – relation2id.txt
train_file (str) – train2id.txt
batch_size (int | None) – batch size
neg_ent (int) – 对于 CompGCN 不起作用。
- __weakref__¶
list of weak references to the object (if defined)
- add_reverse_relation()¶
增加相反关系:r` = r + rel_tol
- add_train_reverse_triples()¶
对于每一个三元组 (h, r, t),生成相反关系三元组 (t, r`, h): r` = r + rel_tol。
- batch_size: int¶
batch size
- build_graph(num_ent: int, triples: tuple[torch.Tensor, torch.Tensor, torch.Tensor], power: int = -1) tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor]¶
建立子图。
- 参数:
num_ent (int) – 子图的节点数
triples (tuple[torch.Tensor, torch.Tensor, torch.Tensor]) – 知识图谱中的正确三元组子集
power (int) – 幂
- 返回:
子图、关系、边的归一化系数
- 返回类型:
tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor]
- comp_deg_norm(graph: dgl.DGLGraph, power: int = -1) torch.Tensor¶
根据目标节点度计算目标节点的归一化系数。
- 参数:
graph (dgl.DGLGraph) – 子图
power (int) – 幂
- 返回:
节点的归一化系数
- 返回类型:
torch.Tensor
- ent2id: dict¶
实体->ID
- ent_file: str¶
entity2id.txt
- ent_tol: int¶
实体的个数
- get_hr2t_rt2h_from_train()¶
获得
hr2t_train和rt2h_train。
- get_train() list[tuple[int, int, int]]¶
返回训练集三元组。
- 返回:
- 返回类型:
list[tuple[int, int, int]]
- get_train_triples_id()¶
读取
train_file文件。
- hr2t_train: collections.defaultdict[set]¶
训练集中所有 h-r 对对应的 t 集合
- id2ent: dict¶
ID->实体
- id2rel: dict¶
ID->关系
- in_path: str¶
数据集目录
- neg_ent: int¶
对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail)
- node_norm_to_edge_norm(graph: dgl.DGLGraph, node_norm: torch.Tensor) torch.Tensor[源代码]¶
根据源节点和目标节点的度计算每条边的归一化系数。
- 参数:
graph (dgl.DGLGraph) – 子图的节点数
node_norm (torch.Tensor) – 节点的归一化系数
- 返回:
边的归一化系数
- 返回类型:
torch.Tensor
- rel2id: dict¶
关系->ID
- rel_file: str¶
relation2id.txt
- rel_tol: int¶
关系的个数
- rt2h_train: collections.defaultdict[set]¶
训练集中所有 r-t 对对应的 h 集合
- sampling(pos_hr_t: list[tuple[tuple[int, int], list[int]]]) dict[str, Union[dgl.DGLGraph, torch.Tensor]][源代码]¶
CompGCN[VSNT20] 的采样函数。- 参数:
pos_triples (list[tuple[tuple[int, int], list[int]]]) – 知识图谱中的正确三元组
- 返回:
CompGCN[VSNT20] 的训练数据- 返回类型:
dict[str, Union[dgl.DGLGraph, torch.Tensor]]
- sampling_negative(mode: int, pos_triples: list[tuple[int, int, int]]) numpy.ndarray¶
采样负三元组。
- 参数:
mode (str) – ‘head’ 或 ‘tail’
pos_triples (list[tuple[int, int, int]]) – 知识图谱中的正确三元组
- 返回:
负三元组
- 返回类型:
numpy.ndarray
- sampling_positive(positive_triples: list[tuple[int, int, int]]) tuple[numpy.ndarray, torch.Tensor]¶
为创建子图重新采样三元组子集,重排实体 ID。
- 参数:
pos_triples (list[tuple[int, int, int]]) – 知识图谱中的正确三元组
- 返回:
三元组子集和原始的实体 ID
- 返回类型:
tuple[numpy.ndarray, torch.Tensor]
- train_file: str¶
train2id.txt
- train_tol: int¶
训练集三元组的个数
- train_triples: list[tuple[int, int, int]]¶
训练集三元组