Shortcuts

Tester

class unike.config.Tester(model: Model | None = None, data_loader: KGEDataLoader | None = None, sampling_mode: str = 'link_test', prediction: str = 'all', use_tqdm: bool = True, use_gpu: bool = True, device: str = 'cuda:0', only_test: bool = False)[源代码]

主要用于 KGE 模型的评估。

例子:

from unike.data import KGEDataLoader, BernSampler, TradTestSampler
from unike.module.model import TransE
from unike.module.loss import MarginLoss
from unike.module.strategy import NegativeSampling
from unike.config import Trainer, Tester

# dataloader for training
dataloader = KGEDataLoader(
        in_path = "../../benchmarks/FB15K/", 
        batch_size = 8192,
        neg_ent = 25,
        test = True,
        test_batch_size = 256,
        num_workers = 16,
        train_sampler = BernSampler,
        test_sampler = TradTestSampler
)

# define the model
transe = TransE(
        ent_tol = dataloader.train_sampler.ent_tol,
        rel_tol = dataloader.train_sampler.rel_tol,
        dim = 50, 
        p_norm = 1, 
        norm_flag = True)

# define the loss function
model = NegativeSampling(
        model = transe, 
        loss = MarginLoss(margin = 1.0),
        regul_rate = 0.01
)

# test the model
tester = Tester(model = transe, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

# train the model
trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
        epochs = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
        tester = tester, test = True, valid_interval = 100,
        log_interval = 100, save_interval = 100,
        save_path = '../../checkpoint/transe.pth', delta = 0.01)
trainer.run()
__init__(model: Model | None = None, data_loader: KGEDataLoader | None = None, sampling_mode: str = 'link_test', prediction: str = 'all', use_tqdm: bool = True, use_gpu: bool = True, device: str = 'cuda:0', only_test: bool = False)[源代码]

创建 Tester 对象。

参数:
  • model (unike.module.model.Model) – KGE 模型

  • data_loader (unike.data.KGEDataLoader) – py:class:unike.data.KGEDataLoader

  • sampling_mode (str) – 评估验证集还是测试集:’link_test’ or ‘link_valid’

  • prediction (str) – 链接预测模式: ‘all’’head’’tail’

  • use_tqdm (bool) – 是否启用进度条

  • use_gpu (bool) – 是否使用 gpu

  • device (str) – 使用哪个 gpu

  • only_test (bool) – 是否是评估已经训练好的模型

__weakref__

list of weak references to the object (if defined)

data_loader: KGEDataLoader

unike.data.KGEDataLoader

device: torch.device

gpu,利用 device 构造的 torch.device 对象

hits: list[int] = [1, 3, 10]

准备报告的指标 Hit@N 的列表,默认为 [1, 3, 10], 表示报告 Hits@1, Hits@3, Hits@10

model: Model

KGE 模型,即 unike.module.model.Model

prediction: str

链接预测模式: ‘all’’head’’tail’

进行链接预测。

返回:

经典指标分别为 MR,MRR,Hits@1,Hits@3,Hits@10

返回类型:

dict[str, float]

sampling_mode: str

unike.data.TestDataLoader 负采样的方式:’link_test’ or ‘link_valid’

set_hits(new_hits: list[int] = [1, 3, 10])[源代码]

定义 Hits 指标。

参数:

new_hits (list[int]) – 准备报告的指标 Hit@N 的列表,默认为 [1, 3, 10], 表示报告 Hits@1, Hits@3, Hits@10

set_sampling_mode(sampling_mode: str)[源代码]

设置 sampling_mode

参数:

sampling_mode (str) – 数据采样模式,’link_test’‘link_valid’ 分别表示为链接预测进行测试集和验证集的负采样

test_dataloader: torch.utils.data.DataLoader

测试数据加载器。

to_var(x: torch.Tensor) torch.Tensor[源代码]

根据 use_gpu 返回 x 的张量

参数:

x (torch.Tensor) – 数据

返回:

张量

返回类型:

torch.Tensor

use_gpu: bool

是否使用 gpu

use_tqdm: bool

是否启用进度条

val_dataloader: torch.utils.data.DataLoader

验证数据加载器。

Docs

Access comprehensive developer documentation for UniKE

View Docs