# Source code for autogl.module.model.pyg.graph_saint

import typing as _typing
import torch.nn.functional
from torch_geometric.nn.conv import MessagePassing
from torch_sparse import SparseTensor, matmul

from . import register_model
from .base import ClassificationModel, ClassificationSupportedSequentialModel


class _GraphSAINTAggregationLayers:
    """Namespace for the layers that make up the GraphSAINT aggregation model.

    Nested classes:

    * ``MultiOrderAggregationLayer`` -- concatenates an order-0 (self) term and,
      when its order is 1, an order-1 (neighbour-sum) term, then applies
      optional dropout.
    * ``WrappedDropout`` -- applies dropout to a tensor or to ``data.x``.

    The private static helpers hold the argument validation and the activation /
    batch-normalisation construction shared by the two per-order aggregators.
    """

    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged; used when no known activation is requested.

        A named function (rather than a lambda) keeps modules that store it
        picklable, e.g. by ``torch.save(model)``.
        """
        return x

    @staticmethod
    def _validate_transform_arguments(
        input_dimension: int, output_dimension: int, bias: bool
    ) -> None:
        """Check the arguments of a per-order linear transform.

        :raises TypeError: if either dimension is not exactly ``int`` (so
            ``bool`` is rejected) or ``bias`` is not a ``bool``.
        :raises ValueError: if either dimension is not positive.
        """
        if not type(input_dimension) == type(output_dimension) == int:
            raise TypeError
        if not (input_dimension > 0 and output_dimension > 0):
            raise ValueError
        if type(bias) != bool:
            raise TypeError

    @staticmethod
    def _resolve_activation(
        activation: _typing.Optional[str],
    ) -> _typing.Callable[[torch.Tensor], torch.Tensor]:
        """Map an activation name to a callable.

        ``"relu"`` and ``"elu"`` match case-insensitively; any other string is
        looked up case-sensitively as a callable attribute of
        ``torch.nn.functional``.  ``None``, non-strings and unknown names fall
        back to the identity function.
        """
        if not isinstance(activation, str):
            return _GraphSAINTAggregationLayers._identity
        normalized_name = activation.lower()
        if normalized_name == "relu":
            return torch.nn.functional.relu
        if normalized_name == "elu":
            return torch.nn.functional.elu
        candidate = getattr(torch.nn.functional, activation, None)
        if callable(candidate):
            return candidate
        return _GraphSAINTAggregationLayers._identity

    @staticmethod
    def _build_batch_norm(
        batch_norm: bool, output_dimension: int
    ) -> _typing.Optional[torch.nn.BatchNorm1d]:
        """Return ``BatchNorm1d(output_dimension, eps=1e-8)``, or ``None`` if disabled.

        :raises TypeError: if ``batch_norm`` is not a ``bool``.
        """
        if type(batch_norm) != bool:
            raise TypeError
        if not batch_norm:
            return None
        return torch.nn.BatchNorm1d(output_dimension, 1e-8)

    class MultiOrderAggregationLayer(torch.nn.Module):
        """One GraphSAINT aggregation layer of order 0 or 1.

        The output is the order-0 term, concatenated along ``dim=1`` with the
        order-1 term when the order is 1, followed by optional dropout.  Its
        width is :attr:`integral_output_dimension`.
        """

        class Order0Aggregator(torch.nn.Module):
            """Order-0 term: ``batch_norm(activation(linear(x)))`` on each node's own features.

            The edge arguments of :meth:`forward` exist only so both aggregators
            share a call signature; they are ignored here.
            """

            def __init__(
                self,
                input_dimension: int,
                output_dimension: int,
                bias: bool = True,
                activation: _typing.Optional[str] = "ReLU",
                batch_norm: bool = True,
            ):
                super().__init__()
                _GraphSAINTAggregationLayers._validate_transform_arguments(
                    input_dimension, output_dimension, bias
                )
                self.__linear_transform = torch.nn.Linear(
                    input_dimension, output_dimension, bias
                )
                self.__linear_transform.reset_parameters()
                self.__activation = _GraphSAINTAggregationLayers._resolve_activation(
                    activation
                )
                self.__optional_batch_normalization: _typing.Optional[
                    torch.nn.BatchNorm1d
                ] = _GraphSAINTAggregationLayers._build_batch_norm(
                    batch_norm, output_dimension
                )

            def forward(
                self,
                x: _typing.Union[
                    torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]
                ],
                _edge_index: torch.Tensor,
                _edge_weight: _typing.Optional[torch.Tensor] = None,
                _size: _typing.Optional[_typing.Tuple[int, int]] = None,
            ) -> torch.Tensor:
                if isinstance(x, tuple):
                    # (source, target) input: the self term belongs to the
                    # target nodes, matching the order-1 output rows.
                    x = x[1]
                output = self.__activation(self.__linear_transform(x))
                if self.__optional_batch_normalization is not None:
                    output = self.__optional_batch_normalization(output)
                return output

        class Order1Aggregator(MessagePassing):
            """Order-1 term: sum-aggregated neighbour features, then linear/activation/BN.

            Messages are the source features ``x_j``, scaled by ``edge_weight``
            when one is given, and combined with ``aggr="add"``.
            """

            def __init__(
                self,
                input_dimension: int,
                output_dimension: int,
                bias: bool = True,
                activation: _typing.Optional[str] = "ReLU",
                batch_norm: bool = True,
            ):
                super().__init__(aggr="add")
                _GraphSAINTAggregationLayers._validate_transform_arguments(
                    input_dimension, output_dimension, bias
                )
                self.__linear_transform = torch.nn.Linear(
                    input_dimension, output_dimension, bias
                )
                self.__linear_transform.reset_parameters()
                self.__activation = _GraphSAINTAggregationLayers._resolve_activation(
                    activation
                )
                self.__optional_batch_normalization: _typing.Optional[
                    torch.nn.BatchNorm1d
                ] = _GraphSAINTAggregationLayers._build_batch_norm(
                    batch_norm, output_dimension
                )

            def forward(
                self,
                x: _typing.Union[
                    torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]
                ],
                _edge_index: torch.Tensor,
                _edge_weight: _typing.Optional[torch.Tensor] = None,
                _size: _typing.Optional[_typing.Tuple[int, int]] = None,
            ) -> torch.Tensor:
                if isinstance(x, torch.Tensor):
                    # Homogeneous graph: nodes are both sources and targets.
                    x = (x, x)
                output = self.propagate(
                    _edge_index, x=x, edge_weight=_edge_weight, size=_size
                )
                output = self.__activation(self.__linear_transform(output))
                if self.__optional_batch_normalization is not None:
                    output = self.__optional_batch_normalization(output)
                return output

            def message(
                self, x_j: torch.Tensor, edge_weight: _typing.Optional[torch.Tensor]
            ) -> torch.Tensor:
                return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

            def message_and_aggregate(
                self,
                adj_t: SparseTensor,
                x: _typing.Union[
                    torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]
                ],
            ) -> torch.Tensor:
                # Sparse fast path: multiply the adjacency with source features.
                return matmul(adj_t, x[0], reduce=self.aggr)

        @property
        def integral_output_dimension(self) -> int:
            """Width of this layer's output: one block per order (0..order)."""
            return (self._order + 1) * self._each_order_output_dimension

        def __init__(
            self,
            _input_dimension: int,
            _each_order_output_dimension: int,
            _order: int,
            bias: bool = True,
            activation: _typing.Optional[str] = "ReLU",
            batch_norm: bool = True,
            _dropout: _typing.Optional[float] = ...,
        ):
            """Build the order-0 (and, for order 1, order-1) transforms.

            :param _dropout: dropout rate applied to the concatenated output; any
                ``int``/``float`` (not ``bool``) is clamped into ``[0, 1]``.
                ``None`` or the ``...`` default disables dropout.
            :raises TypeError: on non-``int`` dimensions/order or non-``bool`` bias.
            :raises ValueError: on non-positive dimensions or an order not in
                ``(0, 1)``.
            """
            super().__init__()
            if not (
                type(_input_dimension) == type(_order) == int
                and type(_each_order_output_dimension) == int
            ):
                raise TypeError
            if _input_dimension <= 0 or _each_order_output_dimension <= 0:
                raise ValueError
            if _order not in (0, 1):
                raise ValueError("Unsupported order number")
            self._input_dimension: int = _input_dimension
            self._each_order_output_dimension: int = _each_order_output_dimension
            self._order: int = _order
            if type(bias) != bool:
                raise TypeError
            self.__order0_transform = self.Order0Aggregator(
                self._input_dimension,
                self._each_order_output_dimension,
                bias,
                activation,
                batch_norm,
            )
            self.__order1_transform = (
                self.Order1Aggregator(
                    self._input_dimension,
                    self._each_order_output_dimension,
                    bias,
                    activation,
                    batch_norm,
                )
                if _order == 1
                else None
            )
            if isinstance(_dropout, (int, float)) and not isinstance(_dropout, bool):
                # Accept ints and float subclasses (e.g. numpy floats) too; they
                # were previously ignored silently.
                self.__optional_dropout: _typing.Optional[
                    torch.nn.Dropout
                ] = torch.nn.Dropout(min(max(float(_dropout), 0.0), 1.0))
            else:
                self.__optional_dropout = None

        def _forward(
            self,
            x: _typing.Union[torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]],
            edge_index: torch.Tensor,
            edge_weight: _typing.Optional[torch.Tensor] = None,
            size: _typing.Optional[_typing.Tuple[int, int]] = None,
        ) -> torch.Tensor:
            """Compute the per-order terms, concatenate them and apply dropout."""
            if self.__order1_transform is not None:
                output = torch.cat(
                    [
                        self.__order0_transform(x, edge_index, edge_weight, size),
                        self.__order1_transform(x, edge_index, edge_weight, size),
                    ],
                    dim=1,
                )
            else:
                output = self.__order0_transform(x, edge_index, edge_weight, size)
            if self.__optional_dropout is not None:
                output = self.__optional_dropout(output)
            return output

        def forward(self, data) -> torch.Tensor:
            """Apply the layer to ``data.x`` / ``data.edge_index`` (optional ``edge_weight``).

            :raises TypeError: if ``x`` or ``edge_index`` is not a tensor, or
                ``edge_weight`` is present but not a tensor.
            """
            x: torch.Tensor = getattr(data, "x")
            if not isinstance(x, torch.Tensor):
                raise TypeError
            edge_index: torch.LongTensor = getattr(data, "edge_index")
            if not isinstance(edge_index, torch.Tensor):
                raise TypeError
            edge_weight: _typing.Optional[torch.Tensor] = getattr(
                data, "edge_weight", None
            )
            if edge_weight is not None and not isinstance(edge_weight, torch.Tensor):
                raise TypeError
            return self._forward(x, edge_index, edge_weight)

    class WrappedDropout(torch.nn.Module):
        """Apply a dropout module to a tensor, or to the ``x`` tensor of a data object."""

        def __init__(self, dropout_module: torch.nn.Dropout):
            super().__init__()
            self.__dropout_module: torch.nn.Dropout = dropout_module

        def forward(self, tenser_or_data) -> torch.Tensor:
            """Return dropout applied to the tensor (or to ``.x``).

            :raises TypeError: if the input is neither a tensor nor an object
                whose ``x`` attribute is a tensor.
            """
            if isinstance(tenser_or_data, torch.Tensor):
                return self.__dropout_module(tenser_or_data)
            features = getattr(tenser_or_data, "x", None)
            if isinstance(features, torch.Tensor):
                return self.__dropout_module(features)
            raise TypeError


class GraphSAINTMultiOrderAggregationModel(ClassificationSupportedSequentialModel):
    """Node classifier that stacks GraphSAINT multi-order aggregation layers.

    Encoder: input dropout, then one ``MultiOrderAggregationLayer`` per entry of
    ``_layers_order_list``, each fed the previous layer's output width.
    Decoder: optional L2 normalisation, a linear layer and ``log_softmax``.

    :param num_features: input feature dimension.
    :param num_classes: number of output classes.
    :param _output_dimension_for_each_order: output width of each order term.
    :param _layers_order_list: order of every layer; must be non-empty.
    :param _pre_dropout: dropout rate on the input features, clamped to ``[0, 1]``.
    :param _layers_dropout: a single rate for every layer (clamped to ``[0, 1]``)
        or a sequence with exactly one rate per layer.
    :param activation: activation name forwarded to every layer.
    :param bias: whether the linear transforms use a bias.
    :param batch_norm: forwarded to every layer.
    :param normalize: whether embeddings are L2-normalised before decoding.
    :raises TypeError: if ``_output_dimension_for_each_order`` is not an ``int``
        or a dropout argument is not a number (``bool`` is rejected).
    :raises ValueError: if ``_output_dimension_for_each_order`` is not
        positive, ``_layers_order_list`` is empty, or a per-layer dropout
        sequence has the wrong length.
    """

    @staticmethod
    def _is_dropout_rate(value) -> bool:
        """Return whether ``value`` is an ``int``/``float`` (but not a ``bool``)."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def _clamp_probability(value) -> float:
        """Clamp ``value`` into ``[0.0, 1.0]`` and return it as a ``float``."""
        return min(max(float(value), 0.0), 1.0)

    def __init__(
        self,
        num_features: int,
        num_classes: int,
        _output_dimension_for_each_order: int,
        _layers_order_list: _typing.Sequence[int],
        _pre_dropout: float,
        _layers_dropout: _typing.Union[float, _typing.Sequence[float]],
        activation: _typing.Optional[str] = "ReLU",
        bias: bool = True,
        batch_norm: bool = True,
        normalize: bool = True,
    ):
        super(GraphSAINTMultiOrderAggregationModel, self).__init__()
        if type(_output_dimension_for_each_order) != int:
            raise TypeError
        if not _output_dimension_for_each_order > 0:
            raise ValueError
        if len(_layers_order_list) == 0:
            raise ValueError("_layers_order_list must contain at least one layer")
        self._layers_order_list: _typing.Sequence[int] = _layers_order_list

        if isinstance(_layers_dropout, _typing.Sequence):
            if len(_layers_dropout) != len(_layers_order_list):
                raise ValueError
            self._layers_dropout: _typing.Sequence[float] = _layers_dropout
        elif self._is_dropout_rate(_layers_dropout):
            # Broadcast one clamped float rate to every layer.
            shared_rate = self._clamp_probability(_layers_dropout)
            self._layers_dropout: _typing.Sequence[float] = [
                shared_rate for _ in _layers_order_list
            ]
        else:
            raise TypeError
        if not self._is_dropout_rate(_pre_dropout):
            raise TypeError
        pre_dropout = self._clamp_probability(_pre_dropout)

        self.__sequential_encoding_layers: torch.nn.ModuleList = torch.nn.ModuleList(
            [
                _GraphSAINTAggregationLayers.WrappedDropout(
                    torch.nn.Dropout(pre_dropout)
                )
            ]
        )
        # Index the expanded per-layer list, not the raw argument: a scalar
        # rate is not subscriptable.
        layer_input_dimension = num_features
        for layer_order, layer_dropout in zip(
            self._layers_order_list, self._layers_dropout
        ):
            layer = _GraphSAINTAggregationLayers.MultiOrderAggregationLayer(
                layer_input_dimension,
                _output_dimension_for_each_order,
                layer_order,
                bias,
                activation,
                batch_norm,
                layer_dropout,
            )
            self.__sequential_encoding_layers.append(layer)
            layer_input_dimension = layer.integral_output_dimension
        self.__apply_normalize: bool = normalize
        self.__linear_transform: torch.nn.Linear = torch.nn.Linear(
            layer_input_dimension,
            num_classes,
            bias,
        )
        self.__linear_transform.reset_parameters()

    def cls_decode(self, x: torch.Tensor) -> torch.Tensor:
        """Map embeddings to per-class log-probabilities (``log_softmax`` over ``dim=1``)."""
        if self.__apply_normalize:
            x = torch.nn.functional.normalize(x, p=2, dim=1)
        return torch.nn.functional.log_softmax(self.__linear_transform(x), dim=1)

    def cls_encode(self, data) -> torch.Tensor:
        """Run ``data`` through every encoding layer and return the embeddings.

        ``data`` must expose tensor ``x`` and ``edge_index`` attributes and,
        optionally, a tensor ``edge_weight``.  Each layer reads its input from
        ``data.x``, so ``data.x`` serves as a scratch slot during encoding; the
        caller's original tensor is restored afterwards (also on error), so the
        same ``data`` object can be encoded again.

        :raises TypeError: if a required attribute is not a tensor.
        """
        if not isinstance(getattr(data, "x"), torch.Tensor):
            raise TypeError
        if not isinstance(getattr(data, "edge_index"), torch.Tensor):
            raise TypeError
        edge_weight = getattr(data, "edge_weight", None)
        if edge_weight is not None and not isinstance(edge_weight, torch.Tensor):
            raise TypeError
        original_features = getattr(data, "x")
        try:
            for encoding_layer in self.__sequential_encoding_layers:
                setattr(data, "x", encoding_layer(data))
            return getattr(data, "x")
        finally:
            setattr(data, "x", original_features)

    @property
    def sequential_encoding_layers(self) -> torch.nn.ModuleList:
        """The input-dropout wrapper followed by the aggregation layers."""
        return self.__sequential_encoding_layers


@register_model("GraphSAINTAggregationModel")
class GraphSAINTAggregationModel(ClassificationModel):
    """Registered AutoGL model that wraps :class:`GraphSAINTMultiOrderAggregationModel`.

    Registered under the name ``"GraphSAINTAggregationModel"``.  The network is
    built by :meth:`_initialize` from ``self.hyper_parameter``.
    """

    def __init__(
        self,
        num_features: int = ...,
        num_classes: int = ...,
        device: _typing.Union[str, torch.device] = ...,
        init: bool = False,
        **kwargs
    ):
        super(GraphSAINTAggregationModel, self).__init__(
            num_features, num_classes, device=device, init=init, **kwargs
        )
        # todo: Initialize with default hyper parameter space and hyper parameter

    def _initialize(self):
        """Initialize model

        Builds the network from ``self.hyper_parameter`` (read with ``.get``)
        using the keys ``output_dimension_for_each_order``,
        ``layers_order_list``, ``pre_dropout`` and ``layers_dropout`` (no
        defaults), ``activation`` (default ``"ReLU"``) and ``bias``,
        ``batch_norm``, ``normalize`` (default ``True``, coerced with
        ``bool``).  The result is moved to ``self.device``.

        NOTE(review): presumably called by the ``ClassificationModel`` base
        (e.g. when ``init=True``), which also sets ``num_features``,
        ``num_classes``, ``device`` and ``hyper_parameter`` -- confirm there.
        """
        self.model = GraphSAINTMultiOrderAggregationModel(
            self.num_features,
            self.num_classes,
            self.hyper_parameter.get("output_dimension_for_each_order"),
            self.hyper_parameter.get("layers_order_list"),
            self.hyper_parameter.get("pre_dropout"),
            self.hyper_parameter.get("layers_dropout"),
            self.hyper_parameter.get("activation", "ReLU"),
            bool(self.hyper_parameter.get("bias", True)),
            bool(self.hyper_parameter.get("batch_norm", True)),
            bool(self.hyper_parameter.get("normalize", True)),
        ).to(self.device)