Source code for autogl.module.nas.algorithm.gasso

# "Graph differentiable architecture search with structure optimization" NeurIPS 21'

import logging

import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F

from . import register_nas_algo
from .base import BaseNAS
from ..estimator.base import BaseEstimator
from ..space import BaseSpace
from ..utils import replace_layer_choice, replace_input_choice
from ...model.base import BaseAutoModel

from torch.autograd import Variable
import numpy as np
import time
import copy
import torch.optim as optim
import scipy.sparse as sp

_logger = logging.getLogger(__name__)

@register_nas_algo("gasso")
class Gasso(BaseNAS):
    """
    GASSO trainer.

    Parameters
    ----------
    num_epochs : int
        Number of epochs planned for training.
    warmup_epochs : int
        Number of epochs planned for warming up.
    model_lr : float
        Learning rate used to optimize the model weights.
    model_wd : float
        Weight decay used to optimize the model weights.
    arch_lr : float
        Learning rate used to optimize the architecture parameters.
    stru_lr : float
        Learning rate used to optimize the graph structure.
    lamb : float
        Weight controlling the influence of hidden-feature smoothness
        in the structure loss.
    device : str or torch.device
        The device on which the whole process runs.
    """

    def __init__(
        self,
        num_epochs=250,
        warmup_epochs=10,
        model_lr=0.01,
        model_wd=1e-4,
        arch_lr=0.03,
        stru_lr=0.04,
        lamb=0.6,
        device="auto",
    ):
        super().__init__(device=device)
        self.num_epochs = num_epochs
        self.warmup_epochs = warmup_epochs
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.arch_lr = arch_lr
        self.stru_lr = stru_lr
        self.lamb = lamb

    def train_stru(self, model, optimizer, data):
        # One optimization step on the learnable edge weights. The model
        # output is detached, so gradients flow only into the edge weights.
        model.train()
        data[0].adj = self.adjs
        logits = model(data[0]).detach()
        loss = 0
        for adj in self.adjs:
            e1 = adj[0][0]  # source nodes of each edge
            e2 = adj[0][1]  # target nodes of each edge
            ew = adj[1]     # learnable weights of each edge
            # feature smoothness across edges, gated by a soft edge mask
            diff = (logits[e1] - logits[e2]).pow(2).sum(1)
            smooth = (diff * torch.sigmoid(ew)).sum()
            # quadratic regularizer on the raw edge weights
            dist = (ew * ew).sum()
            loss += self.lamb * smooth + dist
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del logits

    def _infer(self, model: BaseSpace, dataset, estimator: BaseEstimator, mask="train"):
        dataset[0].adj = self.adjs
        metric, loss = estimator.infer(model, dataset, mask=mask)
        return metric, loss
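    # For reference, the structure objective optimized in ``train_stru`` above
    # can be written as follows (this summary is inferred from the code, not
    # quoted from the paper):
    #
    #     L_stru = lamb * sum_{(u,v) in E} sigmoid(w_uv) * || z_u - z_v ||^2
    #              + sum_{(u,v) in E} w_uv^2
    #
    # summed over every per-step adjacency, where z are the detached model
    # logits and w the learnable edge weights. sigmoid(w) acts as a soft edge
    # mask, while the quadratic term keeps the raw weights from drifting.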
    def prepare(self, dset):
        """Initialize one set of learnable edge weights per architecture step."""
        data = dset[0]
        self.ews = []
        self.edges = data.edge_index.to(self.device)
        edge_weight = torch.ones(self.edges.size(1)).to(self.device)
        self.adjs = []
        for _ in range(self.steps):
            edge_weight = Variable(edge_weight * 1.0, requires_grad=True).to(self.device)
            self.ews.append(edge_weight)
            self.adjs.append((self.edges, edge_weight))
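    # Each step receives its own independent weight tensor (all initialized to
    # ones), so the learned graph structure can differ between steps of the
    # searched model.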
    def fit(self, data):
        self.optimizer = optim.Adam(
            self.space.parameters(), lr=self.model_lr, weight_decay=self.model_wd
        )
        self.arch_optimizer = optim.Adam(
            self.space.arch_parameters(), lr=self.arch_lr, betas=(0.5, 0.999)
        )
        self.stru_optimizer = optim.SGD(self.ews, lr=self.stru_lr)

        # Train model
        best_performance = 0
        min_val_loss = float("inf")
        t_total = time.time()
        for epoch in range(self.num_epochs):
            # 1. update the model weights on the training split
            self.space.train()
            self.optimizer.zero_grad()
            _, loss = self._infer(self.space, data, self.estimator, "train")
            loss.backward()
            self.optimizer.step()

            # warm the model weights up before touching structure/architecture
            if epoch < self.warmup_epochs:
                continue

            # 2. update the graph structure (edge weights)
            self.train_stru(self.space, self.stru_optimizer, data)

            # 3. update the architecture parameters
            self.arch_optimizer.zero_grad()
            _, loss = self._infer(self.space, data, self.estimator, "train")
            loss.backward()
            self.arch_optimizer.step()

            # track the architecture with the best validation loss so far
            self.space.eval()
            train_acc, _ = self._infer(self.space, data, self.estimator, "train")
            val_acc, val_loss = self._infer(self.space, data, self.estimator, "val")
            if val_loss < min_val_loss:
                min_val_loss = val_loss
                best_performance = val_acc
                self.space.keep_prediction()

        return best_performance, min_val_loss
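    # Note (an observation on the loop above, not a claim from the paper
    # text): both the weight and architecture updates are driven by the
    # training loss; unlike DARTS-style bi-level search, no held-out split
    # feeds the architecture gradients, and the validation split is only
    # used to select the best architecture.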
    def search(self, space: BaseSpace, dataset, estimator):
        self.estimator = estimator
        self.space = space.to(self.device)
        self.steps = space.steps
        self.prepare(dataset)
        perf, val_loss = self.fit(dataset)
        return space.parse_model(None, self.device)
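# Usage sketch (illustrative, not part of this module). It assumes a
# GASSO-style search space exposing ``arch_parameters()``, ``steps`` and
# ``parse_model()``, an estimator implementing ``infer()``, and a
# node-classification dataset whose first element carries ``edge_index``;
# the concrete ``space``, ``dataset`` and ``estimator`` objects below are
# placeholders to be supplied by the caller.
#
#     algo = Gasso(num_epochs=250, warmup_epochs=10, lamb=0.6, device="cuda")
#     best_model = algo.search(space, dataset, estimator)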