TradTestSampler¶
- class unike.data.TradTestSampler(sampler: TradSampler, valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', type_constrain: bool = True)[源代码]¶
平移模型和语义匹配模型的测试数据采样器。
- __init__(sampler: TradSampler, valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', type_constrain: bool = True)[源代码]¶
创建 TradTestSampler 对象。
- 参数:
sampler (TradSampler) – 训练数据采样器。
valid_file (str) – valid2id.txt
test_file (str) – test2id.txt
type_constrain (bool) – 是否报告 type_constrain.txt 限制的测试结果
- __weakref__¶
list of weak references to the object (if defined)
- all_true_triples: set[tuple[int, int, int]]¶
知识图谱所有三元组
- ent_tol: int¶
实体的个数
- get_all_true_triples() set[tuple[int, int, int]]¶
返回知识图谱所有三元组。
- 返回:
- 返回类型:
set[tuple[int, int, int]]
- get_test() list[tuple[int, int, int]]¶
返回测试集三元组。
- 返回:
- 返回类型:
list[tuple[int, int, int]]
- get_type_constrain_id()¶
读取 type_constrain.txt 文件。
- get_valid() list[tuple[int, int, int]]¶
返回验证集三元组。
- 返回:
- 返回类型:
list[tuple[int, int, int]]
- get_valid_test_triples_id()¶
读取
valid_file文件和test_file文件。
- hr2t_all: ddict[set]¶
知识图谱中所有 h-r 对对应的 t 集合
- rel_heads: ddict[set]¶
知识图谱中所有 r 存在头实体种类
- rel_tails: ddict[set]¶
知识图谱中所有 r 存在尾实体种类
- rt2h_all: ddict[set]¶
知识图谱中所有 r-t 对对应的 h 集合
- sampler: Union[TradSampler, RGCNSampler, CompGCNSampler]¶
训练数据采样器
- sampling(data: list[tuple[int, int, int]]) dict[str, torch.Tensor][源代码]¶
采样函数。
- 参数:
data (list[tuple[int, int, int]]) – 测试的正确三元组
- 返回:
测试数据
- 返回类型:
dict[str, torch.Tensor]
- test_file: str¶
test2id.txt
- test_tol: int¶
测试集三元组的个数
- test_triples: list[tuple[int, int, int]]¶
测试集三元组
- type_constrain: bool¶
是否报告 type_constrain.txt 限制的测试结果
- valid_file: str¶
valid2id.txt
- valid_tol: int¶
验证集三元组的个数
- valid_triples: list[tuple[int, int, int]]¶
验证集三元组