Args:
``features`` (``torch.Tensor``): The feature matrix.
``k`` (``int``): The number of nearest neighbors.
"""
assert features.ndim == 2, "The feature matrix should be 2-D."
assert (
k <= features.shape[0]
), "The number of nearest neighbors should be less than or equal to the number of vertices."
dist_matrix = torch.cdist(features, features, p=2)
_, nbr_indices = torch.topk(dist_matrix, k, largest=False)
return nbr_indices.tolist()
@staticmethod
def _e_list_from_feature_kNN(features: torch.Tensor, k: int):
r"""Construct hyperedges from the feature matrix. Each hyperedge in the hypergraph is constructed by the central vertex and its :math:`k-1` neighbor vertices.