Shortcuts

KGEDataLoader

class unike.data.KGEDataLoader(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', batch_size: int | None = None, neg_ent: int = 1, test: bool = False, test_batch_size: int | None = None, type_constrain: bool = True, num_workers: int | None = None, train_sampler: ~typing.Type[~unike.data.UniSampler.UniSampler] | ~typing.Type[~unike.data.BernSampler.BernSampler] | ~typing.Type[~unike.data.RGCNSampler.RGCNSampler] | ~typing.Type[~unike.data.CompGCNSampler.CompGCNSampler] = <class 'unike.data.BernSampler.BernSampler'>, test_sampler: ~typing.Type[~unike.data.TestSampler.TestSampler] = <class 'unike.data.TradTestSampler.TradTestSampler'>)[源代码]

KGE 模型数据加载器。

例子:

from unike.data import KGEDataLoader, BernSampler, TradTestSampler

dataloader = KGEDataLoader(
        in_path = "../../benchmarks/FB15K/", 
        batch_size = 8192,
        neg_ent = 25,
        test = True,
        test_batch_size = 256,
        num_workers = 16,
        train_sampler = BernSampler,
        test_sampler = TradTestSampler
)
__init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', batch_size: int | None = None, neg_ent: int = 1, test: bool = False, test_batch_size: int | None = None, type_constrain: bool = True, num_workers: int | None = None, train_sampler: ~typing.Type[~unike.data.UniSampler.UniSampler] | ~typing.Type[~unike.data.BernSampler.BernSampler] | ~typing.Type[~unike.data.RGCNSampler.RGCNSampler] | ~typing.Type[~unike.data.CompGCNSampler.CompGCNSampler] = <class 'unike.data.BernSampler.BernSampler'>, test_sampler: ~typing.Type[~unike.data.TestSampler.TestSampler] = <class 'unike.data.TradTestSampler.TradTestSampler'>)[源代码]

创建 KGEDataLoader 对象。

参数:
  • in_path (str) – 数据集目录

  • ent_file (str) – entity2id.txt

  • rel_file (str) – relation2id.txt

  • train_file (str) – train2id.txt

  • valid_file (str) – valid2id.txt

  • test_file (str) – test2id.txt

  • batch_size (int | None) – batch size

  • neg_ent (int) – 对于每一个正三元组, 构建的负三元组的个数, 替换 entity;对于 CompGCN 不起作用。

  • test (bool) – 是否读取验证集和测试集

  • test_batch_size (int | None) – test batch size

  • type_constrain (bool) – 是否报告 type_constrain.txt 限制的测试结果

  • num_workers (int) – 加载数据的进程数

  • train_sampler (Union[Type[UniSampler], Type[BernSampler], Type[RGCNSampler], Type[CompGCNSampler]]) – 训练数据采样器

  • test_sampler (Type[TestSampler]) – 测试数据采样器

__weakref__

list of weak references to the object (if defined)

batch_size: int

batch size

data_test: list[tuple[int, int, int]]

测试集三元组

data_train: list[tuple[int, int, int]]

训练集三元组

data_val: list[tuple[int, int, int]]

验证集三元组

ent_file: str

entity2id.txt

get_ent_tol() int[源代码]

返回实体个数。

返回:

实体个数

返回类型:

int

get_rel_tol() int[源代码]

返回关系个数。

返回:

关系个数

返回类型:

int

in_path: str

数据集目录

neg_ent: int

对于每一个正三元组, 构建的负三元组的个数, 替换 entity;对于 CompGCN 不起作用。

num_workers: int

加载数据的进程数

rel_file: str

relation2id.txt

test: bool

是否读取验证集和测试集

test_batch_size: int

test batch size

test_dataloader() torch.utils.data.DataLoader[源代码]

返回测试数据加载器。

返回:

测试数据加载器

返回类型:

torch.utils.data.DataLoader

test_file: str

test2id.txt

test_sampler: TestSampler

测试数据采样器

train_dataloader() torch.utils.data.DataLoader[源代码]

返回训练数据加载器。

返回:

训练数据加载器

返回类型:

torch.utils.data.DataLoader

train_file: str

train2id.txt

train_sampler: UniSampler | BernSampler | RGCNSampler | CompGCNSampler

训练数据采样器

type_constrain: bool

是否报告 type_constrain.txt 限制的测试结果

val_dataloader() torch.utils.data.DataLoader[源代码]

返回验证数据加载器。

返回:

验证数据加载器

返回类型:

torch.utils.data.DataLoader

valid_file: str

valid2id.txt

validate_data_test() None[源代码]

验证测试集和验证集中的实体和关系 id 是否合法。

validate_data_train() None[源代码]

验证训练集中的实体和关系 id 是否合法。

Docs

Access comprehensive developer documentation for UniKE

View Docs