Shortcuts

CompGCN-FB15K237-single-gpu || CompGCN-FB15K237-single-gpu-wandb || CompGCN-FB15K237-single-gpu-hpo

CompGCN-FB15K237-single-gpu-hpo

备注

created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023

备注

updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 23, 2024

备注

last run by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 23, 2024

这一部分介绍如何用一个 GPU 在 FB15K237 知识图谱上寻找 CompGCN [VSNT20] 的超参数。

定义训练数据加载器超参数优化范围

import pprint
import os
from unike.data import get_kge_data_loader_hpo_config
from unike.module.model import get_compgcn_hpo_config
from unike.module.loss import get_compgcn_loss_hpo_config
from unike.module.strategy import get_compgcn_sampling_hpo_config
from unike.config import get_tester_hpo_config
from unike.config import get_trainer_hpo_config
from unike.config import set_hpo_config, start_hpo_train

unike.data.get_kge_data_loader_hpo_config() 将返回 unike.data.KGEDataLoader 的默认超参数优化范围, 你可以修改数据目录等信息。

data_loader_config = get_kge_data_loader_hpo_config()
print("data_loader_config:")
pprint.pprint(data_loader_config)
print()

data_loader_config.update({
    'in_path': {
        'value': os.path.join(os.path.dirname(__file__), '../../benchmarks/FB15K237/')
    },
    'neg_ent': {
        'value': 1
    },
    'train_sampler': {
        'value': 'CompGCNSampler'
    },
    'test_sampler': {
        'value': 'CompGCNTestSampler'
    },
    'test_batch_size': {
        'value': 256
    }
})

定义模型超参数优化范围

unike.module.model.get_compgcn_hpo_config() 返回了 unike.module.model.CompGCN 的默认超参数优化范围。

# set the hpo config
kge_config = get_compgcn_hpo_config()
print("kge_config:")
pprint.pprint(kge_config)
print()

定义损失函数超参数优化范围

unike.module.loss.get_compgcn_loss_hpo_config() 返回了 unike.module.loss.CompGCNLoss 的默认超参数优化范围。

# set the hpo config
loss_config = get_compgcn_loss_hpo_config()
print("loss_config:")
pprint.pprint(loss_config)
print()

定义训练策略超参数优化范围

unike.module.strategy.get_compgcn_sampling_hpo_config() 返回了 unike.module.strategy.CompGCNSampling 的默认超参数优化范围。

# set the hpo config
strategy_config = get_compgcn_sampling_hpo_config()
print("strategy_config:")
pprint.pprint(strategy_config)
print()

定义评估器超参数优化范围

unike.config.get_tester_hpo_config() 返回了 unike.config.Tester 的默认超参数优化范围。

# set the hpo config
tester_config = get_tester_hpo_config()
print("tester_config:")
pprint.pprint(tester_config)
print()

tester_config.update({
    'prediction': {
        'value': 'tail'
    }
})

定义训练器超参数优化范围

unike.config.get_trainer_hpo_config() 返回了 unike.config.Trainer 的默认超参数优化范围。

# set the hpo config
trainer_config = get_trainer_hpo_config()
print("trainer_config:")
pprint.pprint(trainer_config)
print()

trainer_config.update({
    'lr': {
        'distribution': 'uniform',
        'min': 0,
        'max': 0.01
    }
})

设置超参数优化参数

unike.config.set_hpo_config() 可以设置超参数优化参数。

# set the hpo config
sweep_config = set_hpo_config(
    sweep_name = "CompGCN_FB15K237",
    data_loader_config = data_loader_config,
    kge_config = kge_config,
    loss_config = loss_config,
    strategy_config = strategy_config,
    tester_config = tester_config,
    trainer_config = trainer_config)
print("sweep_config:")
pprint.pprint(sweep_config)
print()

开始超参数优化

unike.config.start_hpo_train() 可以开始超参数优化。

# start hpo
start_hpo_train(config=sweep_config, count = 3)

备注

上述代码的运行日志可以从 此处 下载。

备注

上述代码的运行报告可以从 此处 下载。


Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for UniKE

View Docs