Shortcuts

RGCNSampler

class unike.data.RGCNSampler(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]

R-GCN [SKB+18] 的训练数据采样器。

例子:

from unike.data import RGCNSampler, CompGCNSampler
from torch.utils.data import DataLoader

#: 训练数据采样器
train_sampler: typing.Union[typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = train_sampler(
    in_path=in_path,
    ent_file=ent_file,
    rel_file=rel_file,
    train_file=train_file,
    batch_size=batch_size,
    neg_ent=neg_ent
)

#: 训练集三元组
data_train: list[tuple[int, int, int]] = train_sampler.get_train()

train_dataloader = DataLoader(
    data_train,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
    collate_fn=train_sampler.sampling,
)
__init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]

创建 RGCNSampler 对象。

参数:
  • in_path (str) – 数据集目录

  • ent_file (str) – entity2id.txt

  • rel_file (str) – relation2id.txt

  • train_file (str) – train2id.txt

  • batch_size (int | None) – batch size

  • neg_ent (int) – 对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail)

__weakref__

list of weak references to the object (if defined)

add_reverse_relation()

增加相反关系:r` = r + rel_tol

add_train_reverse_triples()

对于每一个三元组 (h, r, t),生成相反关系三元组 (t, r`, h): r` = r + rel_tol。

batch_size: int

batch size

build_graph(num_ent: int, triples: tuple[torch.Tensor, torch.Tensor, torch.Tensor], power: int = -1) tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor][源代码]

建立子图。

参数:
  • num_ent (int) – 子图的节点数

  • triples (tuple[torch.Tensor, torch.Tensor, torch.Tensor]) – 知识图谱中的正确三元组子集

  • power (int) – 幂

返回:

子图、关系、边的归一化系数

返回类型:

tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor]

comp_deg_norm(graph: dgl.DGLGraph, power: int = -1) torch.Tensor[源代码]

根据目标节点度计算目标节点的归一化系数。

参数:
  • graph (dgl.DGLGraph) – 子图

  • power (int) – 幂

返回:

节点的归一化系数

返回类型:

torch.Tensor

ent2id: dict

实体->ID

ent_file: str

entity2id.txt

ent_tol: int

实体的个数

get_hr2t_rt2h_from_train()

获得 hr2t_trainrt2h_train

get_hr_train()

用于 CompGCN [VSNT20] 训练,因为 CompGCN [VSNT20] 的组合运算仅需要头实体和关系。

如果想获得更详细的信息请访问 CompGCN

get_id()

读取 ent_file 文件和 rel_file 文件。

get_train() list[tuple[int, int, int]]

返回训练集三元组。

返回:

train_triples

返回类型:

list[tuple[int, int, int]]

get_train_triples_id()

读取 train_file 文件。

hr2t_train: collections.defaultdict[set]

训练集中所有 h-r 对对应的 t 集合

id2ent: dict

ID->实体

id2rel: dict

ID->关系

in_path: str

数据集目录

neg_ent: int

对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail)

node_norm_to_edge_norm(graph: dgl.DGLGraph, node_norm: torch.Tensor) torch.Tensor[源代码]

根据目标节点度计算每条边的归一化系数。

参数:
  • graph (dgl.DGLGraph) – 子图

  • node_norm (torch.Tensor) – 节点的归一化系数

返回:

边的归一化系数

返回类型:

torch.Tensor

rel2id: dict

关系->ID

rel_file: str

relation2id.txt

rel_tol: int

关系的个数

rt2h_train: collections.defaultdict[set]

训练集中所有 r-t 对对应的 h 集合

sampling(pos_triples: list[tuple[int, int, int]]) dict[str, Union[dgl.DGLGraph, torch.Tensor]][源代码]

R-GCN [SKB+18] 的采样函数。

参数:

pos_triples (list[tuple[int, int, int]]) – 知识图谱中的正确三元组

返回:

R-GCN [SKB+18] 的训练数据

返回类型:

dict[str, Union[dgl.DGLGraph , torch.Tensor]]

sampling_negative(mode: int, pos_triples: list[tuple[int, int, int]]) numpy.ndarray[源代码]

采样负三元组。

参数:
  • mode (str) – ‘head’ 或 ‘tail’

  • pos_triples (list[tuple[int, int, int]]) – 知识图谱中的正确三元组

返回:

负三元组

返回类型:

numpy.ndarray

sampling_positive(positive_triples: list[tuple[int, int, int]]) tuple[numpy.ndarray, torch.Tensor][源代码]

为创建子图重新采样三元组子集,重排实体 ID。

参数:

pos_triples (list[tuple[int, int, int]]) – 知识图谱中的正确三元组

返回:

三元组子集和原始的实体 ID

返回类型:

tuple[numpy.ndarray, torch.Tensor]

train_file: str

train2id.txt

train_tol: int

训练集三元组的个数

train_triples: list[tuple[int, int, int]]

训练集三元组

Docs

Access comprehensive developer documentation for UniKE

View Docs