Bases: MessagePassing
The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.
Note
For an example of using HGT, see examples/hetero/hgt_dblp.py.
in_channels (int or Dict[str, int]) – Size of each input sample of every
node type, or -1 to derive the size from the first input(s)
to the forward method.
out_channels (int) – Size of each output sample.
metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata
of the heterogeneous graph, i.e. its node and edge types given
by a list of strings and a list of string triplets, respectively.
See torch_geometric.data.HeteroData.metadata() for more
information.
heads (int, optional) – Number of attention heads.
(default: 1)
**kwargs (optional) – Additional arguments of
torch_geometric.nn.conv.MessagePassing.
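As an illustration of the shapes described above, the sketch below writes out a metadata tuple and a per-type in_channels dictionary for a toy schema using only plain Python. The node and edge type names, the feature sizes, and the helper find_schema_problems are all hypothetical and not part of PyG. The helper covers only the dictionary form of in_channels, not the int or -1 forms.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[str], List[EdgeType]]

# Hypothetical toy schema: authors write papers, plus the reverse relation.
metadata: Metadata = (
    ["author", "paper"],
    [("author", "writes", "paper"), ("paper", "written_by", "author")],
)

# Per-node-type input feature sizes (made-up values).
in_channels: Dict[str, int] = {"author": 16, "paper": 32}


def find_schema_problems(metadata: Metadata,
                         in_channels: Dict[str, int]) -> List[str]:
    """Hypothetical helper (not part of PyG): list inconsistencies between
    the metadata tuple and a per-type in_channels dictionary."""
    node_types, edge_types = metadata
    problems = []
    # Every edge type is a (source, relation, destination) triplet whose
    # endpoints should appear in the list of node types.
    for edge_type in edge_types:
        src, _, dst = edge_type
        for endpoint in (src, dst):
            if endpoint not in node_types:
                problems.append(
                    f"{edge_type} uses unknown node type {endpoint!r}")
    # Every node type needs an input size when in_channels is a dictionary.
    for node_type in node_types:
        if node_type not in in_channels:
            problems.append(
                f"no input size given for node type {node_type!r}")
    return problems


print(find_schema_problems(metadata, in_channels))  # []
```

For an existing HeteroData object, the same tuple is what torch_geometric.data.HeteroData.metadata() returns, so it rarely needs to be written by hand.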
Runs the forward pass of the module.
x_dict (Dict[str, torch.Tensor]) – A dictionary holding input node features for each individual node type.
edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – A
dictionary holding graph connectivity information for each
individual edge type, either as a torch.Tensor of
shape [2, num_edges] or a
torch_sparse.SparseTensor.
Dict[str, Optional[torch.Tensor]] – The output node
embeddings for each node type.
If a node type does not receive any messages, its output is
set to None.