Source code for autogllight.nas.algorithm.grna

# "Adversarially Robust Neural Architecture Search for Graph Neural Networks"
import logging
import torch
import torch.optim
from .spos import Evolution, UniformSampler, Spos
from ..space import (
    replace_layer_choice,
    replace_input_choice,
    get_module_order,
    sort_replaced_module,
    PathSamplingInputChoice,
    PathSamplingLayerChoice,
)

_logger = logging.getLogger(__name__)


class GRNA(Spos):
    """
    GRNA trainer.

    Parameters
    ----------
    n_warmup : int
        Number of epochs for training the super network.
    grad_clip : float
        Gradient clipping value used when training the super network.
    disable_progress : bool
        Whether to disable the progress bar.
    device : str or torch.device
        Device on which the super network is trained.
    optimize_mode, population_size, sample_size, cycles, mutation_prob
        Evolutionary-search settings; see ``Evolution``.

    Notes
    -----
    The learning rate (``model_lr``) and weight decay (``model_wd``) used for
    the super-network optimizer in ``_prepare`` are inherited from the
    ``Spos`` base class.
    """

    def __init__(
        self,
        n_warmup=1000,
        grad_clip=5.0,
        disable_progress=False,
        optimize_mode="maximize",
        population_size=100,
        sample_size=25,
        cycles=20000,
        mutation_prob=0.05,
        device="cuda",
    ):
        super().__init__(
            n_warmup,
            grad_clip,
            disable_progress,
            optimize_mode,
            population_size,
            sample_size,
            cycles,
            mutation_prob,
            device,
        )

    def _prepare(self):
        # Replace layer and input choices in the model with path-sampling modules.
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

        # Move the super network to the target device and set up its optimizer.
        self.model = self.model.to(self.device)
        self.model_optim = torch.optim.Adam(
            self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
        )

        # Controller that uniformly samples architectures from the search space.
        self.controller = UniformSampler(self.nas_modules)

        # Evolutionary search over the trained super network.
        self.evolve = Evolution(
            optimize_mode="maximize",
            population_size=self.population_size,
            sample_size=self.sample_size,
            cycles=self.cycles,
            mutation_prob=self.mutation_prob,
            disable_progress=self.disable_progress,
        )

    def _infer(self, mask="train"):
        metric, loss = self.estimator.infer(self.model, self.dataset, mask=mask)
        return metric, loss
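
For orientation, a minimal usage sketch follows. Only the GRNA constructor and its keyword arguments are taken from this module; the search space, dataset, and estimator objects, as well as the ``search(...)`` entry point, are assumptions about the surrounding AutoGL-light API and may differ in the actual library.

# Hypothetical usage sketch. The GRNA constructor and its arguments come from
# the source above; the space/dataset/estimator objects and the search(...)
# call are illustrative assumptions, not guaranteed by this module.
from autogllight.nas.algorithm.grna import GRNA

trainer = GRNA(
    n_warmup=1000,        # epochs of super-network warm-up training
    grad_clip=5.0,        # gradient clipping for the super network
    population_size=100,  # evolutionary-search population size
    sample_size=25,       # candidates sampled per evolution step
    cycles=20000,         # total evolution cycles
    mutation_prob=0.05,   # per-choice mutation probability
    device="cuda",
)

# Assumed entry point (illustrative names):
# best_architecture = trainer.search(space, dataset, estimator)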