import torch
import torch.nn.functional as F
from torch_geometric.nn import GraphConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from . import register_model
from .base import BaseAutoModel, activate_func
from ....utils import get_logger

# Module-level logger for this model file, created via the project's get_logger helper.
LOGGER = get_logger("TopkModel")

def set_default(args, d):
    """Fill entries missing from ``args`` with the defaults in ``d``.

    Keys already present in ``args`` keep their existing value. ``args`` is
    modified in place and is also returned, so calls can be chained.
    """
    for key, default in d.items():
        if key in args:
            continue  # a caller-supplied value always wins over the default
        args[key] = default
    return args

class Topkpool(torch.nn.Module):
    """Graph classifier built from three GraphConv + TopKPooling stages.

    Each stage applies ``GraphConv`` (hidden size 128) followed by ReLU and
    ``TopKPooling``. After every stage the node embeddings are read out per
    graph by concatenating global max-pooling and global mean-pooling
    (2 * 128 = 256 features). The three readouts are summed, optionally
    concatenated with graph-level features, and passed through a three-layer
    MLP that ends in ``log_softmax``.

    Parameters
    ----------
    args: `dict`
        Must contain:

        - ``features_num``: input node-feature dimension.
        - ``num_class``: number of output classes.
        - ``num_graph_features``: dimension of extra graph-level features;
          0 disables them.
        - ``ratio``: the ``ratio`` argument passed to every ``TopKPooling``.
        - ``dropout``: dropout probability applied after the first linear
          layer.
        - ``act``: activation name passed to ``activate_func`` in the MLP.

    Raises
    ------
    Exception
        If any required key is missing from ``args``.
    """

    # Every key that __init__ or forward reads from ``args``.
    _REQUIRED_KEYS = (
        "features_num",
        "num_class",
        "num_graph_features",
        "ratio",
        "dropout",
        "act",
    )

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Sorted so the error message is deterministic (set order is not).
        missing_keys = sorted(set(self._REQUIRED_KEYS) - set(self.args.keys()))
        if missing_keys:
            raise Exception("Missing keys: %s." % ",".join(missing_keys))

        self.num_features = self.args["features_num"]
        self.num_classes = self.args["num_class"]
        self.ratio = self.args["ratio"]
        self.dropout = self.args["dropout"]
        self.num_graph_features = self.args["num_graph_features"]

        self.conv1 = GraphConv(self.num_features, 128)
        self.pool1 = TopKPooling(128, ratio=self.ratio)
        self.conv2 = GraphConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=self.ratio)
        self.conv3 = GraphConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=self.ratio)

        # 256 = concatenated max + mean readout of the 128-dim embeddings.
        self.lin1 = torch.nn.Linear(256 + self.num_graph_features, 128)
        self.lin2 = torch.nn.Linear(128, 64)
        self.lin3 = torch.nn.Linear(64, self.num_classes)

    @staticmethod
    def _readout(x, batch):
        """Concatenate per-graph global max-pooled and mean-pooled node features."""
        return torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

    def forward(self, data):
        """Return per-graph log-probabilities over ``num_class`` classes.

        ``data`` must expose ``x``, ``edge_index`` and ``batch``. When
        ``num_graph_features > 0`` it must also carry the graph-level
        feature tensor (see the note below).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        readouts = []
        for conv, pool in (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        ):
            x = F.relu(conv(x, edge_index))
            # Pooling drops nodes, so edge_index and batch are updated as well.
            x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)
            readouts.append(self._readout(x, batch))

        # Aggregate the readouts from all three hierarchy levels.
        x = sum(readouts)

        if self.num_graph_features > 0:
            # NOTE(review): graph-level features are expected in ``data.gf``;
            # confirm this attribute name against the dataset pipeline.
            x = torch.cat([x, data.gf], dim=-1)

        x = self.lin1(x)
        x = activate_func(x, self.args["act"])
        # F.dropout defaults to training=True; tie it to the module mode so
        # dropout is disabled under model.eval().
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        x = activate_func(x, self.args["act"])
        x = F.log_softmax(self.lin3(x), dim=-1)

        return x

@register_model("topkpool-model")
class AutoTopkpool(BaseAutoModel):
    r"""
    AutoTopkpool.

    Automodel wrapper around the ``Topkpool`` network: GraphConv layers
    interleaved with TopKPooling, with a global max/mean readout after each
    pooling stage, followed by an MLP classifier.

    Searched hyper-parameters, with their defaults:

    - ``ratio``: DOUBLE in [0.1, 0.9], default 0.8.
    - ``dropout``: DOUBLE in [0.1, 0.9], default 0.5.
    - ``act``: one of ``leaky_relu``, ``relu``, ``elu``, ``tanh``; default
      ``relu``.

    Parameters
    ----------
    num_features: `int`.
        The dimension of features.

    num_classes: `int`.
        The number of classes.

    device: `torch.device` or `str`
        The device where model will be running on.

    init: `bool`.
        Accepted for interface compatibility. This constructor does not use
        it and does not forward it to ``BaseAutoModel``.

    num_graph_features: `int`.
        Dimension of extra graph-level features concatenated before the MLP.
        0 (the default) disables them.

    **args:
        Extra keyword arguments forwarded to ``BaseAutoModel``.
    """

    def __init__(
        self,
        num_features=None,
        num_classes=None,
        device=None,
        init=False,
        num_graph_features=0,
        **args
    ):
        # NOTE(review): ``init`` is not forwarded to BaseAutoModel here;
        # confirm initialization is handled elsewhere before relying on it.
        super().__init__(
            num_features,
            num_classes,
            device,
            num_graph_features=num_graph_features,
            **args
        )
        self.num_graph_features = num_graph_features
        self.hyper_parameter_space = [
            {
                "parameterName": "ratio",
                "type": "DOUBLE",
                "maxValue": 0.9,
                "minValue": 0.1,
                "scalingType": "LINEAR",
            },
            {
                "parameterName": "dropout",
                "type": "DOUBLE",
                "maxValue": 0.9,
                "minValue": 0.1,
                "scalingType": "LINEAR",
            },
            {
                "parameterName": "act",
                "type": "CATEGORICAL",
                "feasiblePoints": ["leaky_relu", "relu", "elu", "tanh"],
            },
        ]
        self.hyper_parameters = {"ratio": 0.8, "dropout": 0.5, "act": "relu"}

    def from_hyper_parameter(self, hp, **kwargs):
        """Delegate to ``BaseAutoModel.from_hyper_parameter``, keeping ``num_graph_features``."""
        # ``num_graph_features`` is not part of ``hyper_parameter_space``,
        # so it has to be forwarded explicitly.
        return super().from_hyper_parameter(
            hp, num_graph_features=self.num_graph_features, **kwargs
        )

    def _initialize(self):
        """Build the underlying ``Topkpool`` network and move it to ``self.device``."""
        # Hyper-parameters are spread last, so they override duplicate keys.
        self._model = Topkpool(
            {
                "features_num": self.input_dimension,
                "num_class": self.output_dimension,
                "num_graph_features": self.num_graph_features,
                **self.hyper_parameters,
            }
        ).to(self.device)