unike.utils.tools 源代码
# coding:utf-8
#
# unike/utils/tools.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 3, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
#
# This script defines the import_class and construct_type_constrain utility functions.
import importlib
import os
def import_class(module_and_class_name: str) -> type:
    """Import a class from its fully qualified dotted name.

    The string is split at its last dot: the left part is imported as a
    module via ``importlib.import_module`` and the right part is looked up
    on that module with ``getattr``.

    :param module_and_class_name: Dotted module path followed by the class
        name, e.g. **unike.module.model.TransE**.
    :type module_and_class_name: str
    :returns: The object found under the class name in the imported module.
    :rtype: type
    :raises ValueError: If the name contains no dot, so it cannot be split
        into a module part and a class part.
    :raises ImportError: Propagated from ``importlib.import_module`` when the
        module part cannot be imported.
    :raises AttributeError: Propagated from ``getattr`` when the module has no
        attribute with the requested class name.
    """
    module_name, separator, class_name = module_and_class_name.rpartition(".")
    if not separator:
        # Without a dot there is no module to import; fail with a message
        # that names the offending input instead of an unpacking error.
        raise ValueError(
            "expected a dotted path such as 'package.module.ClassName', "
            f"got {module_and_class_name!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
def _collect_type_constraints(path: str, rel_head: dict, rel_tail: dict) -> None:
    """Read one ``*2id.txt`` triple file and record its heads/tails per relation.

    The first line of the file is the number of triples; each of the
    following lines holds three whitespace-separated fields in the order
    ``head tail relation``. The inner dicts are used as insertion-ordered
    sets, so every head/tail is kept once, in first-seen order.

    :param path: Path of the triple file to read.
    :param rel_head: Mapping ``relation -> {head: 1}``, updated in place.
    :param rel_tail: Mapping ``relation -> {tail: 1}``, updated in place.
    :raises ValueError: If the count line is not an integer or a triple line
        does not contain exactly three fields (including a file that ends
        before the announced number of triples was read).
    """
    with open(path, "r", encoding="utf-8") as triples:
        total = int(triples.readline())
        for index in range(total):
            fields = triples.readline().split()
            if len(fields) != 3:
                # index is 0-based and the count line comes first, hence +2.
                raise ValueError(
                    f"{path}: line {index + 2} should contain "
                    f"'head tail relation', got {len(fields)} field(s)"
                )
            head, tail, rel = fields
            rel_head.setdefault(rel, {})[head] = 1
            rel_tail.setdefault(rel, {})[tail] = 1


def construct_type_constrain(
    in_path: str = "./",
    train_file: str = "train2id.txt",
    valid_file: str = "valid2id.txt",
    test_file: str = "test2id.txt"
) -> None:
    """Build the ``type_constrain.txt`` file for a dataset directory.

    ``type_constrain.txt`` is the type-constraint file. Its first line is the
    number of relations. It is followed by two tab-separated lines per
    relation, listing every head and every tail entity that occurs with that
    relation in the training, validation and test sets:

    first line: **rel_id** **heads_num** **head1** **head2** ...

    second line: **rel_id** **tails_num** **tail1** **tail2** ...

    For example, relation 1200 of benchmarks/FB15K has 4 head entities
    (3123, 1034, 58 and 5733) and 4 tail entities (12123, 4388, 11087 and
    11088):

    1200 4 3123 1034 58 5733

    1200 4 12123 4388 11087 11088

    Relations, heads and tails are written in the order in which they are
    first encountered while reading the training, validation and test files
    (in that order). The result is written to ``type_constrain.txt`` inside
    ``in_path``.

    :param in_path: Dataset directory (with or without a trailing separator).
    :type in_path: str
    :param train_file: Name of the training triple file, train2id.txt.
    :type train_file: str
    :param valid_file: Name of the validation triple file, valid2id.txt.
    :type valid_file: str
    :param test_file: Name of the test triple file, test2id.txt.
    :type test_file: str
    :raises ValueError: If one of the input files is malformed.
    """
    rel_head: dict = {}
    rel_tail: dict = {}

    # All inputs are read before the output is opened, so a malformed input
    # never leaves a partially written type_constrain.txt behind.
    for file_name in (train_file, valid_file, test_file):
        _collect_type_constraints(
            os.path.join(in_path, file_name), rel_head, rel_tail
        )

    out_path = os.path.join(in_path, "type_constrain.txt")
    with open(out_path, "w", encoding="utf-8") as out:
        out.write("%d\n" % len(rel_head))
        # rel_head and rel_tail always share the same keys because both are
        # updated together for every triple.
        for rel, heads in rel_head.items():
            tails = rel_tail[rel]
            out.write("\t".join([rel, str(len(heads)), *heads]) + "\n")
            out.write("\t".join([rel, str(len(tails)), *tails]) + "\n")