TradSampler¶
- class unike.data.TradSampler(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]¶
平移模型和语义匹配模型的采样器的基类。
- __init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]¶
创建 TradSampler 对象。
- 参数:
in_path (str) – 数据集目录
ent_file (str) – entity2id.txt
rel_file (str) – relation2id.txt
train_file (str) – train2id.txt
batch_size (int | None) – batch size 在该采样器中不起作用,只是占位符。
neg_ent (int) – 对于每一个正三元组, 构建的负三元组的个数, 替换 entity
- __weakref__¶
list of weak references to the object (if defined)
- batch_size: int¶
batch size
- corrupt_head(t: int, r: int, num_max: int = 1) numpy.ndarray[源代码]¶
替换头实体构建负三元组。
- 参数:
t (int) – 尾实体
r (int) – 关系
num_max (int) – 一次负采样的个数
- 返回:
负三元组的头实体列表
- 返回类型:
numpy.ndarray
- corrupt_tail(h: int, r: int, num_max: int = 1) numpy.ndarray[源代码]¶
替换尾实体构建负三元组。
- 参数:
h (int) – 头实体
r (int) – 关系
num_max (int) – 一次负采样的个数
- 返回:
负三元组的尾实体列表
- 返回类型:
numpy.ndarray
- ent2id: dict¶
实体->ID
- ent_file: str¶
entity2id.txt
- ent_tol: int¶
实体的个数
- get_hr2t_rt2h_from_train()¶
获得
hr2t_train和rt2h_train。
- get_train() list[tuple[int, int, int]]¶
返回训练集三元组。
- 返回:
- 返回类型:
list[tuple[int, int, int]]
- get_train_triples_id()¶
读取
train_file文件。
- hr2t_train: collections.defaultdict[set]¶
训练集中所有 h-r 对对应的 t 集合
- id2ent: dict¶
ID->实体
- id2rel: dict¶
ID->关系
- in_path: str¶
数据集目录
- neg_ent: int¶
对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail)
- rel2id: dict¶
关系->ID
- rel_file: str¶
relation2id.txt
- rel_tol: int¶
关系的个数
- rt2h_train: collections.defaultdict[set]¶
训练集中所有 r-t 对对应的 h 集合
- sampling(pos_triples: list[tuple[int, int, int]]) dict[str, Union[str, torch.Tensor]][源代码]¶
平移模型和语义匹配模型的训练集数据采样函数。该方法未实现,子类必须重写该方法,否则抛出
NotImplementedError错误。- 参数:
pos_triples (list[tuple[int, int, int]]) – 知识图谱中的正确三元组
- 返回:
平移模型和语义匹配模型的训练数据
- 返回类型:
dict[str, Union[str, torch.Tensor]]
- train_file: str¶
train2id.txt
- train_tol: int¶
训练集三元组的个数
- train_triples: list[tuple[int, int, int]]¶
训练集三元组