import numpy as np
import torch
import typing as _typing
from autogl.data.graph import GeneralStaticGraph
from .._base_feature_engineer import BaseFeatureEngineer
from .._feature_engineer_registry import FeatureEngineerUniversalRegistry
class BaseFeatureSelector(BaseFeatureEngineer):
def __init__(self):
self._selection : _typing.Optional[torch.Tensor] = None
super(BaseFeatureSelector, self).__init__()
def __transform_homogeneous_static_graph(
self, static_graph: GeneralStaticGraph
) -> GeneralStaticGraph:
if (
'x' in static_graph.nodes.data and
isinstance(self._selection, (torch.Tensor, np.ndarray))
):
static_graph.nodes.data['x'] = static_graph.nodes.data['x'][:, self._selection]
if (
'feat' in static_graph.nodes.data and
isinstance(self._selection, (torch.Tensor, np.ndarray))
):
static_graph.nodes.data['feat'] = static_graph.nodes.data['feat'][:, self._selection]
return static_graph
def _transform(
self, data: _typing.Union[GeneralStaticGraph, _typing.Any]
) -> _typing.Union[GeneralStaticGraph, _typing.Any]:
if isinstance(data, GeneralStaticGraph):
return self.__transform_homogeneous_static_graph(data)
elif (
hasattr(data, 'x') and isinstance(data.x, torch.Tensor) and
torch.is_tensor(data.x) and data.x.dim() == 2
):
data.x = data.x[:, self._selection]
return data
else:
return data
[docs]@FeatureEngineerUniversalRegistry.register_feature_engineer("FilterConstant")
class FilterConstant(BaseFeatureSelector):
r"""drop constant features"""
def _fit(self, static_graph: GeneralStaticGraph) -> GeneralStaticGraph:
if (
'x' in static_graph.nodes.data and
self._selection not in (Ellipsis, None) and
isinstance(self._selection, torch.Tensor) and
torch.is_tensor(self._selection) and self._selection.dim() == 1
):
feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['x'].numpy()
elif (
'feat' in static_graph.nodes.data and
self._selection not in (Ellipsis, None) and
isinstance(self._selection, torch.Tensor) and
torch.is_tensor(self._selection) and self._selection.dim() == 1
):
feature: _typing.Optional[np.ndarray] = static_graph.nodes.data['feat'].numpy()
else:
feature: _typing.Optional[np.ndarray] = None
self._selection: _typing.Optional[torch.Tensor] = torch.from_numpy(
np.where(np.all(feature == feature[0, :], axis=0) == np.array(False))[0]
if feature is not None and isinstance(feature, np.ndarray) and feature.ndim == 2
else None
)
return static_graph