Source code for autogllight.nas.space.gasso

import typing as _typ
from . import BaseSpace
import torch
from .gasso_space import *
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F

GNN_LIST = [
    "gat",  # GAT with 2 heads
    "gcn",  # GCN
    "gin",  # GIN
    # "cheb",  # chebnet
    "sage",  # sage
    # "arma",
    # "sg",  # simplifying gcn
    "linear",  # skip connection
    # "skip",  # skip connection
    # "zero",  # skip connection
]


class MixedOp(nn.Module):
    def __init__(self, in_c, out_c, gnn_list=GNN_LIST):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        self.gnn_list = gnn_list
        for action in gnn_list:
            self._ops.append(gnn_map(action, in_c, out_c))

    def forward(self, x, edge_index, edge_weight, weights, selected_idx=None):
        gnn_list = self.gnn_list
        if selected_idx is None:
            fin = []
            for w, op, op_name in zip(weights, self._ops, gnn_list):
                """if op_name == "gcn":
                    w = 1.0
                else:
                    continue"""
                if edge_weight is None:
                    fin.append(w * op(x, edge_index))
                else:
                    fin.append(w * op(x, edge_index, edge_weight=edge_weight))
            return sum(fin)
            # return sum(w * op(x, edge_index) for w, op in zip(weights, self._ops))
        else:  # unchosen operations are pruned
            return self._ops[selected_idx](x, edge_index)
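

# Illustrative sketch (hypothetical helper, not part of the original
# module): MixedOp relaxes the discrete choice among candidate ops into a
# softmax-weighted sum, DARTS-style. It assumes the ops built by gnn_map
# accept (x, edge_index), which is how the edge_weight-is-None branch
# above calls them.
def _demo_mixed_op():
    op = MixedOp(in_c=16, out_c=32)
    x = torch.randn(10, 16)                     # 10 nodes, 16 features each
    edge_index = torch.randint(0, 10, (2, 40))  # 40 random directed edges
    alpha = torch.randn(len(GNN_LIST))          # one logit per candidate op
    out = op(x, edge_index, None, F.softmax(alpha, dim=-1))
    return out.shape                            # torch.Size([10, 32])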


def Get_edges(adjs):
    # Split the per-layer (edge_index, edge_logits) pairs and squash the
    # logits through a sigmoid so each edge weight lies in (0, 1).
    edges = []
    edges_weights = []
    for adj in adjs:
        edges.append(adj[0])
        edges_weights.append(torch.sigmoid(adj[1]))
    return edges, edges_weights
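

# Illustrative sketch (hypothetical helper): the `adjs` structure this
# module passes around is one (edge_index, edge_logits) pair per layer;
# Get_edges keeps the indices and sigmoids the logits into (0, 1) weights.
def _demo_get_edges():
    adjs = [
        (torch.randint(0, 10, (2, 40)), torch.randn(40)),  # layer 0
        (torch.randint(0, 10, (2, 40)), torch.randn(40)),  # layer 1
    ]
    edges, edge_weights = Get_edges(adjs)
    assert all((w > 0).all() and (w < 1).all() for w in edge_weights)
    return edges, edge_weights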


class CellWS(nn.Module):
    def __init__(self, steps, his_dim, hidden_dim, out_dim, dp, bias=True):
        super(CellWS, self).__init__()
        self.steps = steps
        self._ops = nn.ModuleList()
        self._bns = nn.ModuleList()
        self.use2 = False
        self.dp = dp  # dropout probability applied between steps
        for i in range(self.steps):
            if i == 0:
                inpdim = his_dim
            else:
                inpdim = hidden_dim
            if i == self.steps - 1:
                oupdim = out_dim
            else:
                oupdim = hidden_dim
            op = MixedOp(inpdim, oupdim)
            self._ops.append(op)
            self._bns.append(nn.BatchNorm1d(oupdim))

    def forward(self, x, adjs, weights):
        edges, ews = Get_edges(adjs)
        for i in range(self.steps):
            if i > 0:
                x = F.relu(x)
                x = F.dropout(x, p=self.dp, training=self.training)
            x = self._ops[i](x, edges[i], ews[i], weights[i])  # apply this layer's mixed op
        return x
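

# Illustrative sketch (hypothetical helper): run a 2-step cell with one
# (edge_index, edge_logits) pair and one architecture-weight vector per
# step. It assumes the underlying ops accept an edge_weight keyword, as
# the mixed op's weighted branch requires.
def _demo_cell():
    cell = CellWS(steps=2, his_dim=16, hidden_dim=32, out_dim=7, dp=0.8)
    x = torch.randn(10, 16)
    adjs = [(torch.randint(0, 10, (2, 40)), torch.randn(40)) for _ in range(2)]
    weights = [F.softmax(torch.randn(len(GNN_LIST)), dim=-1) for _ in range(2)]
    return cell(x, adjs, weights).shape  # torch.Size([10, 7])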


class GassoSpace(BaseSpace):
    def __init__(
        self,
        hidden_dim: _typ.Optional[int] = 64,
        layer_number: _typ.Optional[int] = 2,
        dropout: _typ.Optional[float] = 0.8,
        input_dim: _typ.Optional[int] = None,
        output_dim: _typ.Optional[int] = None,
        ops: _typ.Tuple = GNN_LIST,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.steps = layer_number
        self.dropout = dropout
        self.ops = ops
        self.use_forward = True
        self.dead_tensor = torch.nn.Parameter(
            torch.FloatTensor([1]), requires_grad=True
        )
    def build_graph(self):
        his_dim, cur_dim, hidden_dim, out_dim = (
            self.input_dim,
            self.input_dim,
            self.hidden_dim,
            self.hidden_dim,
        )
        self.cells = nn.ModuleList()
        self.cell = CellWS(
            self.steps, his_dim, hidden_dim, self.output_dim, self.dropout
        )
        his_dim = cur_dim
        cur_dim = self.steps * out_dim
        self.classifier = nn.Linear(cur_dim, self.output_dim)
        self.initialize_alphas()
    # def forward(self, x, adjs):
    def forward(self, data, adjs):
        if self.use_forward:
            # x, adjs = data.x, data.adj
            x = data.x
            x = F.dropout(x, p=self.dropout, training=self.training)
            weights = []
            for j in range(self.steps):
                weights.append(F.softmax(self.alphas_normal[j], dim=-1))
            x = self.cell(x, adjs, weights)
            x = F.log_softmax(x, dim=1)
            self.current_pred = x.detach()
            return x
        else:
            # for i in self.parameters():
            #     print(i)
            x = self.prediction + self.dead_tensor * 0
            return x
    def keep_prediction(self):
        self.prediction = self.current_pred

    # def to(self, *args, **kwargs):
    #     fin = super().to(*args, **kwargs)
    #     device = next(fin.parameters()).device
    #     fin.alphas_normal = [i.to(device) for i in self.alphas_normal]
    #     return fin

    def initialize_alphas(self):
        num_ops = len(self.ops)
        self.alphas_normal = []
        for i in range(self.steps):
            self.alphas_normal.append(
                Variable(1e-3 * torch.randn(num_ops), requires_grad=True)
            )
        self._arch_parameters = [self.alphas_normal]

    def arch_parameters(self):
        return self.alphas_normal
    def parse_model(self, selection):
        self.use_forward = False
        return self.wrap()
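

# Illustrative end-to-end sketch (hypothetical helper, not part of the
# autogllight API). It assumes a PyG-style data object exposing `x`.
# Because initialize_alphas keeps the architecture logits outside the
# module's registered parameters, a DARTS-style setup gives them their
# own optimizer for the architecture-update step.
def _demo_gasso_space():
    import types

    space = GassoSpace(hidden_dim=32, layer_number=2, input_dim=16, output_dim=7)
    space.build_graph()

    # Toy graph: 10 nodes, 16 features, one (edge_index, edge_logits)
    # pair per layer.
    data = types.SimpleNamespace(x=torch.randn(10, 16))
    adjs = [(torch.randint(0, 10, (2, 40)), torch.randn(40)) for _ in range(2)]
    out = space(data, adjs)  # log-probabilities, shape [10, 7]

    # Weight updates and architecture updates use separate optimizers.
    model_opt = torch.optim.Adam(space.parameters(), lr=1e-2)
    arch_opt = torch.optim.Adam(space.arch_parameters(), lr=3e-4)
    return out, model_opt, arch_opt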