import netlsd
import networkx
import torch
from autogl.data.graph import GeneralStaticGraph
from autogl.data.graph.utils import conversion
from .._base_feature_engineer import BaseFeatureEngineer
from .._feature_engineer_registry import FeatureEngineerUniversalRegistry
[docs]@FeatureEngineerUniversalRegistry.register_feature_engineer("NetLSD".lower())
class NetLSD(BaseFeatureEngineer):
r"""
Notes
-----
a graph feature generation method. This is a simple wrapper of NetLSD [#]_.
References
----------
.. [#] A. Tsitsulin, D. Mottin, P. Karras, A. Bronstein, and E. Müller, “NetLSD: Hearing the shape of a graph,”
Proc. ACM SIGKDD Int. Conf. Knowl. Discov. Data Min., pp. 2347–2356, 2018.
"""
def __init__(self, *args, **kwargs):
self.__args = args
self.__kwargs = kwargs
super(NetLSD, self).__init__()
def __extract(self, nx_g: networkx.Graph) -> torch.Tensor:
return torch.tensor(netlsd.heat(nx_g, *self.__args, **self.__kwargs)).view(-1)
def __transform_homogeneous_static_graph(
self, homogeneous_static_graph: GeneralStaticGraph
) -> GeneralStaticGraph:
if not (
homogeneous_static_graph.nodes.is_homogeneous and
homogeneous_static_graph.edges.is_homogeneous
):
raise ValueError("Provided static graph must be homogeneous")
dsc: torch.Tensor = self.__extract(
conversion.HomogeneousStaticGraphToNetworkX(to_undirected=True).__call__(
homogeneous_static_graph, to_undirected=True
)
)
if 'gf' in homogeneous_static_graph.data:
gf = homogeneous_static_graph.data['gf'].view(-1)
homogeneous_static_graph.data['gf'] = torch.cat([gf, dsc])
else:
homogeneous_static_graph.data['gf'] = dsc
return homogeneous_static_graph
@classmethod
def __edge_index_to_nx_graph(cls, edge_index: torch.Tensor) -> networkx.Graph:
g: networkx.Graph = networkx.Graph()
for u, v in edge_index.t().tolist():
if u == v:
continue
else:
g.add_edge(u, v)
return g
def __transform_data(self, data):
if not (
hasattr(data, "edge_index") and
torch.is_tensor(data.edge_index) and
isinstance(data.edge_index, torch.Tensor) and
data.edge_index.dim() == data.edge_index.size(0) == 2 and
data.edge_index.dtype == torch.long
):
raise TypeError("Unsupported provided data")
dsc: torch.Tensor = self.__extract(self.__edge_index_to_nx_graph(data.edge_index))
if hasattr(data, 'gf') and isinstance(data.gf, torch.Tensor):
gf = data.gf.view(-1)
data.gf = torch.cat([gf, dsc])
else:
data.gf = dsc
return data
def _transform(self, data):
if isinstance(data, GeneralStaticGraph):
return self.__transform_homogeneous_static_graph(data)
else:
return self.__transform_data(data)