Computes graph edges to the nearest k points.
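Example:
import torch
from torch_geometric.nn import knn_graph
# Four 2-D points at the corners of a square, all belonging to one example.
x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
batch = torch.tensor([0, 0, 0, 0])
# Connect every node to its 2 nearest neighbors, without self-loops.
edge_index = knn_graph(x, k=2, batch=batch, loop=False)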
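To make the semantics concrete, here is a brute-force reference in plain PyTorch for a single example (no batch vector). The helper knn_graph_reference is hypothetical and O(N^2) in memory; it is not how PyG or torch_cluster perform the search. On the square above it should yield the same edge set as knn_graph (possibly in a different order), with row 0 holding each neighbor and row 1 the node it is a neighbor of, as in the default flow="source_to_target".

```python
import torch


def knn_graph_reference(x: torch.Tensor, k: int, loop: bool = False) -> torch.Tensor:
    """Hypothetical brute-force k-NN graph for a single example.

    Returns edges in the "source_to_target" layout: row 0 holds the
    neighbor j, row 1 holds the query node i that j is a neighbor of.
    """
    dist = torch.cdist(x, x)  # pairwise Euclidean distances, shape [N, N]
    if not loop:
        # Exclude each node from its own neighbor list.
        dist.fill_diagonal_(float("inf"))
    # Indices of the k smallest distances in every row, shape [N, k].
    nbr = dist.topk(k, dim=1, largest=False).indices
    # Query node of every selected neighbor: [0]*k + [1]*k + ...
    query = torch.arange(x.size(0)).repeat_interleave(k)
    return torch.stack([nbr.reshape(-1), query], dim=0)


x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
edge_index = knn_graph_reference(x, k=2)
print(edge_index)
```

With loop=True this sketch lets the node itself, at distance zero, occupy one of the k slots instead of adding a self-loop on top of k neighbors. As far as we can tell this mirrors knn_graph, which searches k + 1 neighbors and drops self-matches when loop=False, but treat that as an assumption.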
Parameters:
x (torch.Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
k (int) – The number of neighbors.
batch (torch.Tensor, optional) – Batch vector
\(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each
node to a specific example. (default: None)
loop (bool, optional) – If True, the graph will contain
self-loops. (default: False)
flow (str, optional) – The flow direction when used in combination with
message passing ("source_to_target" or
"target_to_source"). (default: "source_to_target")
cosine (bool, optional) – If True, uses the cosine distance
instead of the Euclidean distance to find nearest neighbors.
(default: False)
num_workers (int, optional) – Number of workers to use for computation.
Has no effect if batch is not None or if the input
lies on the GPU. (default: 1)
batch_size (int, optional) – The number of examples \(B\).
Automatically calculated if not given. (default: None)
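The batch vector restricts the neighbor search to nodes of the same example. A minimal brute-force sketch of that behavior, using a hypothetical helper knn_graph_batched_reference that is not part of PyG, sets distances between nodes of different examples to infinity before picking the k nearest neighbors, and drops any pick that is still infinite because an example has fewer than k candidates. The batch vector itself is built from per-example node counts with torch.repeat_interleave.

```python
import torch


def knn_graph_batched_reference(x: torch.Tensor, k: int, batch: torch.Tensor,
                                loop: bool = False) -> torch.Tensor:
    """Hypothetical brute-force k-NN graph that respects a batch vector."""
    dist = torch.cdist(x, x)
    # Never connect nodes that belong to different examples.
    dist[batch.unsqueeze(0) != batch.unsqueeze(1)] = float("inf")
    if not loop:
        dist.fill_diagonal_(float("inf"))
    k_eff = min(k, x.size(0))  # topk cannot return more than N entries per row
    vals, nbr = dist.topk(k_eff, dim=1, largest=False)
    query = torch.arange(x.size(0)).unsqueeze(1).expand_as(nbr)
    # Discard picks whose distance is infinite (not a valid neighbor).
    valid = torch.isfinite(vals)
    return torch.stack([nbr[valid], query[valid]], dim=0)


# Two examples with 3 and 2 nodes: batch = [0, 0, 0, 1, 1].
sizes = torch.tensor([3, 2])
batch = torch.repeat_interleave(torch.arange(sizes.numel()), sizes)
x = torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 0.0], [1.1, 0.0], [9.0, 0.0]])
edge_index = knn_graph_batched_reference(x, k=1, batch=batch)
print(edge_index)
```

Node 3 lies only 0.1 away from node 1, but because they belong to different examples its single neighbor is node 4.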
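The two flow values describe the same neighbor relation stored with the rows in opposite order; as far as we can tell, "target_to_source" is simply the row-swapped form of the default. The sketch below hard-codes a hypothetical edge_index for the 2-nearest-neighbor graph of the four square corners and sum-aggregates one scalar feature per node with index_add_, reading sources and targets the way a message passing layer with the matching flow would. Under both conventions every node receives the sum over its k nearest neighbors.

```python
import torch

# Neighbor relation in "source_to_target" layout:
# row 0 = neighbor j (message source), row 1 = query node i (message target).
s2t = torch.tensor([[1, 2, 0, 3, 0, 3, 1, 2],
                    [0, 0, 1, 1, 2, 2, 3, 3]])
# "target_to_source" layout: the same relation with the two rows swapped.
t2s = s2t.flip(0)

h = torch.tensor([1.0, 10.0, 100.0, 1000.0])  # one scalar feature per node

# source_to_target: sources in row 0, targets in row 1.
out_s2t = torch.zeros(4).index_add_(0, s2t[1], h[s2t[0]])
# target_to_source: sources in row 1, targets in row 0.
out_t2s = torch.zeros(4).index_add_(0, t2s[0], h[t2s[1]])
print(out_s2t, out_t2s)
```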
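With cosine=True, neighbors are ranked by cosine distance, which depends only on the direction of the feature vectors and not on their length. A small sketch, assuming the common definition of cosine distance as 1 minus cosine similarity (the helper cosine_distance_matrix is hypothetical), shows how the nearest neighbor of the same node can change between the two metrics.

```python
import torch
import torch.nn.functional as F


def cosine_distance_matrix(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: cosine distance = 1 - cosine similarity.
    x_n = F.normalize(x, p=2, dim=1)
    return 1.0 - x_n @ x_n.t()


x = torch.tensor([[1.0, 0.0], [10.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
d_euc = torch.cdist(x, x)
d_cos = cosine_distance_matrix(x)
# Ignore self-matches when looking for the nearest neighbor.
d_euc.fill_diagonal_(float("inf"))
d_cos.fill_diagonal_(float("inf"))

nearest_euc = d_euc[0].argmin().item()  # node 2, closest point in space
nearest_cos = d_cos[0].argmin().item()  # node 1, same direction as node 0
print(nearest_euc, nearest_cos)
```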
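When batch_size is omitted, B has to be derived from the batch vector. Assuming the examples are labelled 0, ..., B-1 (the helper infer_batch_size below is hypothetical, not taken from PyG's source), a natural choice is B = max(b) + 1. The sketch also shows the limit of that inference: an example that owns no nodes never appears in b, so only an explicitly passed B can account for it.

```python
import torch


def infer_batch_size(batch: torch.Tensor) -> int:
    # Hypothetical helper: assumes labels 0, ..., B-1 and that the
    # highest-numbered example has at least one node.
    return int(batch.max()) + 1 if batch.numel() > 0 else 0


batch = torch.tensor([0, 0, 1, 1, 1, 2])
print(infer_batch_size(batch))

# If the caller really has B = 4 examples and the last one is empty,
# its label never occurs in batch and cannot be inferred from it.
print(torch.bincount(batch, minlength=4).tolist())
```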