# Source code for autogl.module.feature._graph._netlsd

import netlsd
import networkx
import torch
from autogl.data.graph import GeneralStaticGraph
from autogl.data.graph.utils import conversion
from .._base_feature_engineer import BaseFeatureEngineer
from .._feature_engineer_registry import FeatureEngineerUniversalRegistry


@FeatureEngineerUniversalRegistry.register_feature_engineer("NetLSD".lower())
class NetLSD(BaseFeatureEngineer):
    r"""
    Notes
    -----
    A graph-level feature generation method. It computes the NetLSD
    descriptor of the whole graph and appends it to the graph feature
    vector ``gf``. This is a thin wrapper around NetLSD [#]_.

    Parameters
    ----------
    *args, **kwargs
        Forwarded verbatim to ``netlsd.heat`` on every extraction.

    References
    ----------
    .. [#] A. Tsitsulin, D. Mottin, P. Karras, A. Bronstein, and E. Müller,
       “NetLSD: Hearing the shape of a graph,” Proc. ACM SIGKDD Int. Conf.
       Knowl. Discov. Data Min., pp. 2347–2356, 2018.
    """

    def __init__(self, *args, **kwargs):
        # Stored before the base initializer runs (same order as before), so
        # every later call to ``netlsd.heat`` receives exactly these arguments.
        self.__args = args
        self.__kwargs = kwargs
        super(NetLSD, self).__init__()

    def __extract(self, nx_g: networkx.Graph) -> torch.Tensor:
        """Return the NetLSD descriptor of ``nx_g`` flattened to a 1-D tensor.

        NOTE(review): the dtype follows whatever ``netlsd.heat`` returns
        (presumably a float64 NumPy array) — confirm if a fixed dtype matters.
        """
        return torch.tensor(
            netlsd.heat(nx_g, *self.__args, **self.__kwargs)
        ).view(-1)

    @staticmethod
    def __merge_graph_feature(existing, descriptor: torch.Tensor) -> torch.Tensor:
        """Append ``descriptor`` to the flattened ``existing`` feature.

        If ``existing`` is ``None`` the descriptor alone is returned.
        """
        if existing is None:
            return descriptor
        # ``reshape`` (unlike ``view``) also works for non-contiguous tensors.
        flat = existing.reshape(-1)
        # The descriptor is created on the CPU; align it with the existing
        # feature's device so ``torch.cat`` does not fail for GPU features.
        return torch.cat([flat, descriptor.to(flat.device)])

    @classmethod
    def __edge_index_to_nx_graph(cls, edge_index: torch.Tensor) -> networkx.Graph:
        """Build an undirected ``networkx.Graph`` from a ``(2, E)`` edge index.

        Self-loops are not added as edges. Their endpoint is still registered
        as a node, so a node whose only incident edge is a self-loop is not
        silently dropped from the graph.

        NOTE(review): nodes that appear in no edge at all cannot be recovered
        from ``edge_index`` alone — confirm whether ``data.num_nodes`` should
        also be honoured here.
        """
        g: networkx.Graph = networkx.Graph()
        for u, v in edge_index.t().tolist():
            if u == v:
                # Keep the node, skip the loop edge.
                g.add_node(u)
            else:
                g.add_edge(u, v)
        return g

    def __transform_homogeneous_static_graph(
        self, homogeneous_static_graph: GeneralStaticGraph
    ) -> GeneralStaticGraph:
        """Append the NetLSD descriptor to ``data['gf']`` of a static graph.

        The graph is converted to an undirected NetworkX graph first. The
        graph is modified in place and also returned.

        Raises
        ------
        ValueError
            If the nodes or the edges of the graph are not homogeneous.
        """
        if not (
            homogeneous_static_graph.nodes.is_homogeneous
            and homogeneous_static_graph.edges.is_homogeneous
        ):
            raise ValueError("Provided static graph must be homogeneous")
        to_networkx = conversion.HomogeneousStaticGraphToNetworkX(to_undirected=True)
        dsc: torch.Tensor = self.__extract(
            to_networkx(homogeneous_static_graph, to_undirected=True)
        )
        if 'gf' in homogeneous_static_graph.data:
            existing = homogeneous_static_graph.data['gf']
        else:
            existing = None
        homogeneous_static_graph.data['gf'] = self.__merge_graph_feature(existing, dsc)
        return homogeneous_static_graph

    def __transform_data(self, data):
        """Append the NetLSD descriptor to ``data.gf`` for edge_index-style data.

        An existing ``data.gf`` is kept (flattened) only if it is a tensor;
        otherwise it is replaced. ``data`` is modified in place and returned.

        Raises
        ------
        TypeError
            If ``data.edge_index`` is missing or is not a 2-D ``torch.long``
            tensor whose first dimension has size 2.
        """
        edge_index = getattr(data, "edge_index", None)
        # ``dim()`` is checked first so 0-dim tensors fail with TypeError
        # instead of an IndexError from ``size(0)``.
        if not (
            isinstance(edge_index, torch.Tensor)
            and edge_index.dim() == 2
            and edge_index.size(0) == 2
            and edge_index.dtype == torch.long
        ):
            raise TypeError("Unsupported provided data")
        dsc: torch.Tensor = self.__extract(self.__edge_index_to_nx_graph(edge_index))
        existing = getattr(data, "gf", None)
        if not isinstance(existing, torch.Tensor):
            existing = None
        data.gf = self.__merge_graph_feature(existing, dsc)
        return data

    def _transform(self, data):
        """Dispatch to the ``GeneralStaticGraph`` path or the ``edge_index`` path."""
        if isinstance(data, GeneralStaticGraph):
            return self.__transform_homogeneous_static_graph(data)
        return self.__transform_data(data)