Source code for autogllight.nas.algorithm.gauss

# "Large-Scale Graph Neural Architecture Search" ICML 22'

import random

import torch
import torch.nn.functional as F
from tqdm import trange

from ..estimator.base import BaseEstimator
from ..space import BaseSpace
from .base import BaseNAS


[docs]class Gauss(BaseNAS): """ GAUSS trainer. Parameters ---------- num_epochs : int Number of epochs planned for training. device : str or torch.device The device of the whole process """ def __init__( self, num_epochs=100, device="auto", disable_progress=False, args=None, ): super().__init__(device=device) self.device = device self.num_epochs = num_epochs self.disable_progress = disable_progress self.args = args def prepare(self, data): self.data = data.to(self.device) # fix random seed of train/val/test split random.seed(2022) masks = list(range(data.num_nodes)) random.shuffle(masks) fold = int(data.num_nodes * 0.1) train_idx = masks[:fold * 6] val_idx = masks[fold * 6: fold * 8] test_idx = masks[fold * 8:] split_idx = { 'train': torch.tensor(train_idx).long(), 'valid': torch.tensor(val_idx).long(), 'test': torch.tensor(test_idx).long() } for key in split_idx: split_idx[key] = split_idx[key].to(self.device) self.train_idx = split_idx['train'].to(self.device) self.valid_idx = split_idx['valid'].to(self.device) self.test_idx = split_idx['test'].to(self.device) def train_graph( self, optimizer, epoch ): self.space.train() optimizer.zero_grad() archs = self.space.sampler.samples(self.args.repeat) if self.args.use_curriculum: judgement = None best_acc = 0 if epoch < self.args.warm_up: # min_ratio = args.min_ratio[0] min_ratio = epoch / self.args.epochs * (self.args.min_ratio[1] - self.args.min_ratio[0]) + self.args.min_ratio[0] else: min_ratio = epoch / self.args.epochs * (self.args.min_ratio[1] - self.args.min_ratio[0]) + self.args.min_ratio[0] # min_ratio = args.min_ratio[1] archs = sorted(archs, key=lambda x:x[1]) for arch, score in archs: ratio = score if self.args.no_baseline else 1. out = self.space.model(self.data, arch)[self.train_idx] if self.args.min_clip > 0: ratio = max(ratio, self.args.min_clip) if self.args.max_clip > 0: ratio = min(ratio, self.args.max_clip) loss = F.nll_loss(out, self.data.y[self.train_idx], reduction="none") / self.args.repeat * ratio aggrement = (out.argmax(dim=1) == self.data.y[self.train_idx]) cur_acc = aggrement.float().mean() if self.args.use_curriculum and (judgement is None or cur_acc > best_acc): # cal the judgement judgement = torch.ones_like(loss).float() bar = 1 / self.space.num_classes wrong_idxs = (~aggrement).nonzero()[:, 0] # .squeeze() # pass by the bar distributions = torch.exp(out) try: wrong_idxs = wrong_idxs[distributions[wrong_idxs].max(dim=1)[0] > min(5 * bar, 0.7)] except: import pdb pdb.set_trace() sorted_idxs = distributions[wrong_idxs].max(dim=1)[0].sort(descending=True)[1][:int(self.args.max_ratio * out.size(0))] wrong_idxs = wrong_idxs[sorted_idxs] if min_ratio < 0: judgement = judgement.bool() judgement[wrong_idxs] = False else: judgement[wrong_idxs] = min_ratio loss = loss.mean() best_acc = cur_acc else: if not self.args.use_curriculum: loss = loss.mean() else: if min_ratio < 0: loss = loss[judgement].mean() else: loss = (loss * judgement).mean() loss.backward() optimizer.step() return loss.item(), cur_acc.item() def _infer(self, mask="train"): if mask == "train": mask = self.train_idx elif mask == "valid": mask = self.valid_idx else: mask = self.test_idx metric, loss = self.estimator.infer(self.space, self.data, self.args.arch, mask=mask) return metric, loss def fit(self): optimizer = torch.optim.Adam(self.space.model.parameters(), lr=self.args.lr, weight_decay=self.args.wd) # eta = self.args.eta best_performance = 0 min_val_loss = float("inf") with trange(self.num_epochs, disable=self.disable_progress) as bar: for epoch in bar: """ space training """ self.space.train() # eta = ( # self.args.eta_max - self.args.eta # ) * epoch / self.num_epochs + self.args.eta optimizer.zero_grad() train_loss, train_acc = self.train_graph( optimizer, epoch ) """ space evaluation """ self.space.eval() train_acc, _ = self._infer("train") val_acc, val_loss = self._infer("val") if min_val_loss > val_loss: min_val_loss, best_performance = val_loss, val_acc self.space.keep_prediction() bar.set_postfix(train_acc=train_acc["acc"], val_acc=val_acc["acc"]) return best_performance, min_val_loss
[docs] def search(self, space: BaseSpace, data, estimator: BaseEstimator): self.estimator = estimator self.space = space.to(self.device) self.prepare(data) perf, val_loss = self.fit() return space.parse_model(None)