实现细节¶
SimplE¶
OpenKE-PyTorch 实现的 SimplE 存在问题。
下面是 SimplE [KP18] 的作者给出的声明:
重要
Hi all, I’m the main author of the SimplE paper. I have received emails asking me if the OpenKE implementation of SimplE is correct or not so I thought I post a public response here. I can confirm that the OpenKE implementation is indeed incorrect and there are two issues (one major, one minor) in it:
Major issue: As pointed out by @dschaehi there’s a major issue in the model definition. SimplE requires two embedding vectors per entity, one to be used when the entity is the head and one to be used when the entity is the tail. In the OpenKE implementation, there is only one embedding vector per entity which hurts the model by making it almost identical to DistMult.
Minor issue: This implementation corresponds to a variant of SimplE which we called SimplE-ignr in the paper. It takes the average of the two predictions during training but only uses one of the predictions during testing (see https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/SimplE.py#L54). The standard SimplE model takes the average of the two predictions for both training and testing.
For a correct pytorch implementation of SimplE, I recommend this repo: https://github.com/baharefatemi/SimplE/blob/master/SimplE.py
关于这个问题的讨论在:https://github.com/thunlp/OpenKE/issues/151 。
因此,遵从 SimplE 作者的建议,依据 https://github.com/baharefatemi/SimplE/blob/master/SimplE.py 实现 SimplE 。
最终实现可以从 这里 得到。
HolE¶
警告
由于 unike.module.model.HolE 的
unike.module.model.HolE._ccorr() (OpenKE-PyTorch 的原始实现)需要
torch.rfft 和 torch.ifft 分别计算实数到复数离散傅里叶变换和复数到复数离散傅立叶逆变换。
pytorch 在版本 1.8.0 移除了上述两个函数,并且在版本 1.7.0 给出了警告。
因此,需要适配到更高版本的 pytorch。
重要
我参考了 PyKEEN 的 hole_interaction 实现 ,重新实现了 unike.module.model.HolE,
使其能够适配到更高版本的 pytorch。
RESCAL¶
我去掉了原始 OpenKE-PyTorch 的 RESCAL 的
predict 的
负号,原因如下:
警告
下面的内容是使用 1.0.0 版本 的实现进行陈述的,与 2.0.0 版本 不相符合。
由于 unike.module.model.RESCAL 采用 unike.module.loss.MarginLoss 进行训练,因此需要正样本评分函数的得分应小于负样本评分函数的得分,
unike.module.model.RESCAL 的评分函数需要添加负号即 unike.module.model.RESCAL._calc() 需要添加负号;
由于 UniKE 使用底层 C++ 模块进行评估模型性能,该模块需要正样本的得分小于负样本的得分,
因此 unike.module.model.RESCAL.predict() 不需要在 unike.module.model.RESCAL.forward() 返回的结果上添加负号。
重要
实验表明,去掉 unike.module.model.RESCAL.predict() 负号能够大幅度改善模型的评估结果。
ANALOGY¶
我去掉了原始 OpenKE-PyTorch 的 Analogy 的
_calc 的
负号,原因如下:
在旧版的 OpenKE-PyTorch 中,
DistMult、
ComplEx、
Analogy 3 者的
_calc 函数都带了负号,并且在
Analogy 原论文的实现 中,
DistMult、
ComplEx、
Analogy 3 者的
score 函数都未带负号。从原论文中也能发现,三者的评分函数的符号应该是一致的。
但是在新版的 OpenKE-PyTorch 中,
三者 DistMult、
ComplEx、
Analogy 的
_calc 函数实现中,仅仅 Analogy 带了负号。
因此,我最终决定去掉 OpenKE-PyTorch 的 Analogy 的
_calc 的
负号。
从运行结果也没发现差异。
最终实现可以从 unike.module.model.Analogy 得到。