CompGCNCov¶
- class unike.module.model.CompGCNCov(*args: Any, **kwargs: Any)[源代码]¶
CompGCN[VSNT20] 图神经网络模块。- __call__(*args: Any, **kwargs: Any) Any¶
Call self as a function.
- __init__(in_channels: int, out_channels: int, act: ~typing.Callable[[torch.Tensor], torch.Tensor] = <function CompGCNCov.<lambda>>, bias: bool = True, drop_rate: float = 0.0, opn: str = 'corr')[源代码]¶
创建 CompGCN 对象。
- 参数:
in_channels (int) – 输入的特征维度
out_channels (int) – 输出的特征维度
act (Callable[[torch.Tensor], torch.Tensor]) – 激活函数
bias (bool) – 是否有偏置
drop_rate (float) – Dropout rate
opn (str) – 组成运算符:’mult’、’sub’、’corr’
- static __new__(cls, *args: Any, **kwargs: Any) Any¶
- __repr__() str¶
Return repr(self).
- __weakref__¶
list of weak references to the object (if defined)
- act: Callable[[torch.Tensor], torch.Tensor]¶
激活函数
- bias: torch.nn.Parameter¶
偏置
- bn: torch.nn.BatchNorm1d¶
BatchNorm
- comp(h: torch.Tensor, r: torch.Tensor) torch.Tensor[源代码]¶
组成运算:’mult’、’sub’、’corr’
- 参数:
h (torch.Tensor) – 头实体嵌入向量
r (torch.Tensor) – 关系嵌入向量
- 返回:
组合后的边数据
- 返回类型:
torch.Tensor
- drop: torch.nn.Dropout¶
用于原始关系和相反关系转换后输出结果的 Dropout
- forward(graph: dgl.DGLGraph, ent_emb: torch.nn.parameter.Parameter, rel_emb: torch.nn.parameter.Parameter, edge_type: torch.Tensor, edge_norm: torch.Tensor) tuple[torch.nn.parameter.Parameter, torch.nn.parameter.Parameter][源代码]¶
定义每次调用时执行的计算。
torch.nn.Module子类必须重写torch.nn.Module.forward()。- 参数:
graph (dgl.DGLGraph) – 子图
ent_emb (torch.nn.parameter.Parameter) – 实体嵌入向量
rel_emb (torch.nn.parameter.Parameter) – 关系嵌入向量
edge_type (torch.Tensor) – 关系 ID
norm (torch.Tensor) – 关系的归一化系数
- 返回:
更新后的实体嵌入和关系嵌入
- 返回类型:
tuple[torch.nn.parameter.Parameter, torch.nn.parameter.Parameter]
- get_param(shape: list[int]) torch.nn.parameter.Parameter[源代码]¶
获得权重矩阵。
- 参数:
shape (list[int]) – 权重矩阵的 shape
- 返回:
权重矩阵
- 返回类型:
torch.nn.parameter.Parameter
- in_channels: int¶
输入的特征维度
- loop_rel: torch.nn.parameter.Parameter¶
自循环关系嵌入向量的转换矩阵
- opn: str¶
组成运算符:’mult’、’sub’、’corr’
- out_channels: int¶
输出的特征维度
- rel: torch.nn.parameter.Parameter¶
关系嵌入向量
- w_rel: torch.nn.parameter.Parameter¶
关系嵌入向量的转换矩阵