Source code for autogllight.nas.algorithm.gasso

# "Graph differentiable architecture search with structure optimization" NeurIPS 21'

import logging

import torch
import torch.optim as optim
from tqdm import trange

from .base import BaseNAS
from ..estimator.base import BaseEstimator
from ..space import BaseSpace

_logger = logging.getLogger(__name__)


class Gasso(BaseNAS):
    """
    GASSO trainer.

    Parameters
    ----------
    num_epochs : int
        Number of epochs planned for training.
    warmup_epochs : int
        Number of epochs planned for warming up the model weights before
        the structure and architecture are optimized.
    model_lr : float
        Learning rate to optimize the model.
    model_wd : float
        Weight decay to optimize the model.
    arch_lr : float
        Learning rate to optimize the architecture.
    stru_lr : float
        Learning rate to optimize the structure.
    lamb : float
        The parameter to control the influence of hidden feature smoothness.
    device : str or torch.device
        The device of the whole process.
    disable_progress : bool
        Whether to disable the tqdm progress bar.
    """

    def __init__(
        self,
        num_epochs=250,
        warmup_epochs=10,
        model_lr=0.01,
        model_wd=1e-4,
        arch_lr=0.03,
        stru_lr=0.04,
        lamb=0.6,
        device="auto",
        disable_progress=False,
    ):
        super().__init__(device=device)
        self.num_epochs = num_epochs
        self.warmup_epochs = warmup_epochs
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.arch_lr = arch_lr
        self.stru_lr = stru_lr
        self.lamb = lamb
        self.disable_progress = disable_progress

    def train_stru(self, model, optimizer, data):
        """Run one optimization step on the graph structure (edge weights)."""
        model.train()
        # Forward pass; the logits are detached so that the structure loss
        # below back-propagates into the edge weights only.
        logits = model(data[0].to(self.device), self.adjs).detach()
        loss = 0
        for adj in self.adjs:
            e1 = adj[0][0]  # source nodes of every edge
            e2 = adj[0][1]  # target nodes of every edge
            ew = adj[1]  # learnable raw edge weights
            # Feature smoothness: edges connecting nodes with very different
            # predictions are pushed towards smaller weights.
            diff = (logits[e1] - logits[e2]).pow(2).sum(1)
            smooth = (diff * torch.sigmoid(ew)).sum()
            # L2 penalty keeps the raw edge weights bounded.
            dist = (ew * ew).sum()
            loss += self.lamb * smooth + dist
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del logits

    def _infer(self, model: BaseSpace, dataset, estimator: BaseEstimator, mask="train"):
        metric, loss = estimator.infer(model, dataset, mask=mask, adjs=self.adjs)
        return metric, loss
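    # The structure loss in train_stru amounts to, per propagation step,
    #
    #     L_stru = lamb * sum_{(i,j) in E} sigmoid(w_ij) * ||z_i - z_j||^2
    #              + sum_{(i,j) in E} w_ij^2
    #
    # where z are the detached logits and w the raw edge weights. A minimal
    # standalone sketch on a hypothetical 3-node, 2-edge toy graph:
    #
    #     edges = torch.tensor([[0, 1], [1, 2]])   # hypothetical toy graph
    #     ew = torch.zeros(2, requires_grad=True)
    #     logits = torch.randn(3, 4)               # detached in train_stru
    #     diff = (logits[edges[0]] - logits[edges[1]]).pow(2).sum(1)
    #     loss = 0.6 * (diff * torch.sigmoid(ew)).sum() + (ew * ew).sum()
    #     loss.backward()                          # gradients reach ew only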
    def prepare(self, dset):
        """Build one learnable edge-weight tensor per propagation step."""
        data = dset[0]
        self.ews = []
        self.edges = data.edge_index.to(self.device)
        edge_weight = torch.ones(self.edges.size(1)).to(self.device)
        self.adjs = []
        for i in range(self.steps):
            # Each propagation step gets its own leaf tensor so the structure
            # optimizer can update every step's weights independently.
            edge_weight = edge_weight.clone().detach().requires_grad_(True)
            self.ews.append(edge_weight)
            self.adjs.append((self.edges, edge_weight))
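    # After prepare(), for a graph with E edges and `steps` propagation steps,
    # the trainer holds (a sketch of the resulting layout, not extra API):
    #
    #     self.adjs == [(edge_index, ew_0), ..., (edge_index, ew_{steps-1})]
    #
    # where edge_index has shape [2, E] and every ew_k has shape [E] with
    # requires_grad=True, so the structure optimizer built in fit() can
    # update each step's edge weights independently.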
    def fit(self, data):
        self.optimizer = optim.Adam(
            self.space.parameters(), lr=self.model_lr, weight_decay=self.model_wd
        )
        self.arch_optimizer = optim.Adam(
            self.space.arch_parameters(), lr=self.arch_lr, betas=(0.5, 0.999)
        )
        self.stru_optimizer = optim.SGD(self.ews, lr=self.stru_lr)

        best_performance = 0
        min_val_loss = float("inf")
        with trange(self.num_epochs, disable=self.disable_progress) as bar:
            for epoch in bar:
                # Step 1: update the model weights on the training split.
                self.space.train()
                self.optimizer.zero_grad()
                _, loss = self._infer(self.space, data, self.estimator, "train")
                loss.backward()
                self.optimizer.step()

                # Warm up the model weights alone before optimizing the
                # structure and the architecture.
                if epoch < self.warmup_epochs:
                    continue

                # Step 2: update the graph structure (edge weights).
                self.train_stru(self.space, self.stru_optimizer, data)

                # Step 3: update the architecture parameters.
                self.arch_optimizer.zero_grad()
                _, loss = self._infer(self.space, data, self.estimator, "train")
                loss.backward()
                self.arch_optimizer.step()

                # Evaluate and keep the predictions of the best model so far.
                self.space.eval()
                train_acc, _ = self._infer(self.space, data, self.estimator, "train")
                val_acc, val_loss = self._infer(self.space, data, self.estimator, "val")
                if val_loss < min_val_loss:
                    min_val_loss = val_loss
                    best_performance = val_acc
                    self.space.keep_prediction()
                bar.set_postfix(train_acc=train_acc["acc"], val_acc=val_acc["acc"])
        return best_performance, min_val_loss
    def search(self, space: BaseSpace, dataset, estimator):
        self.estimator = estimator
        self.space = space.to(self.device)
        self.steps = space.steps
        self.prepare(dataset)
        self.fit(dataset)
        return space.parse_model(None)
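# A minimal usage sketch (not part of the module). The space and estimator
# classes below are hypothetical placeholders; substitute the concrete
# BaseSpace and BaseEstimator subclasses your autogllight installation
# provides, whose constructor signatures may differ. The dataset is assumed
# to be indexable, with dataset[0] exposing edge_index (e.g. a PyG Data).
#
#     space = MyGassoSpace(...)      # hypothetical BaseSpace subclass with .steps
#     estimator = MyEstimator(...)   # hypothetical BaseEstimator subclass
#     algo = Gasso(num_epochs=250, warmup_epochs=10, device="auto")
#     best_model = algo.search(space, dataset, estimator)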