Source code for autogl.module.nas.algorithm.enas

# codes in this file are reproduced from https://github.com/microsoft/nni with some changes.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import register_nas_algo
from .base import BaseNAS
from ..space import BaseSpace
from ..utils import (
    AverageMeterGroup,
    replace_layer_choice,
    replace_input_choice,
    get_module_order,
    sort_replaced_module,
)
from tqdm import tqdm, trange
from .rl import (
    PathSamplingLayerChoice,
    PathSamplingInputChoice,
    ReinforceField,
    ReinforceController,
)
from ....utils import get_logger

LOGGER = get_logger("ENAS")


@register_nas_algo("enas")
class Enas(BaseNAS):
    """
    ENAS trainer.

    Parameters
    ----------
    num_epochs : int
        Number of epochs planned for training.
    n_warmup : int
        Number of epochs for training the super network.
    log_frequency : int
        Step count per logging.
    grad_clip : float
        Gradient clipping. Set to 0 to disable. Default: 5.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to
        ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_lr : float
        Learning rate for RL controller.
    ctrl_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for RL controller.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
    model_lr : float
        Learning rate for super network.
    model_wd : float
        Weight decay for super network.
    disable_progress : bool
        Controls whether to show the progress bar.
    device : str or torch.device
        The device of the whole process, e.g. "cuda", torch.device("cpu")
    """

    def __init__(
        self,
        num_epochs=5,
        n_warmup=100,
        log_frequency=None,
        grad_clip=5.0,
        entropy_weight=0.0001,
        skip_weight=0.8,
        baseline_decay=0.999,
        ctrl_lr=0.00035,
        ctrl_steps_aggregate=20,
        ctrl_kwargs=None,
        model_lr=5e-3,
        model_wd=5e-4,
        disable_progress=True,
        device="auto",
    ):
        super().__init__(device)
        self.num_epochs = num_epochs
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.0
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip
        self.ctrl_kwargs = ctrl_kwargs
        self.ctrl_lr = ctrl_lr
        self.n_warmup = n_warmup
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.disable_progress = disable_progress
    def search(self, space: BaseSpace, dset, estimator):
        self.model = space
        self.dataset = dset  # .to(self.device)
        self.estimator = estimator

        # replace layer/input choices with path-sampling modules
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

        # to device
        self.model = self.model.to(self.device)
        self.model_optim = torch.optim.Adam(
            self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
        )

        # fields
        self.nas_fields = [
            ReinforceField(
                name,
                len(module),
                isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
            )
            for name, module in self.nas_modules
        ]
        self.controller = ReinforceController(
            self.nas_fields, **(self.ctrl_kwargs or {})
        )
        self.ctrl_optim = torch.optim.Adam(
            self.controller.parameters(), lr=self.ctrl_lr
        )

        # warm up supernet
        with tqdm(range(self.n_warmup), disable=self.disable_progress) as bar:
            for i in bar:
                acc, l1 = self._train_model(i)
                with torch.no_grad():
                    val_acc, val_loss = self._infer("val")
                bar.set_postfix(loss=l1, acc=acc, val_acc=val_acc, val_loss=val_loss)

        # alternate supernet training and controller training
        with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
            for i in bar:
                _, l1 = self._train_model(i)
                l2 = self._train_controller(i)
                bar.set_postfix(loss_model=l1, reward_controller=l2)

        selection = self.export()
        return space.parse_model(selection, self.device)
    def _train_model(self, epoch):
        self.model.train()
        self.controller.eval()
        self.model_optim.zero_grad()
        self._resample()
        metric, loss = self._infer()
        loss.backward()
        if self.grad_clip > 0:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.model_optim.step()
        return metric, loss.item()

    def _train_controller(self, epoch):
        self.model.eval()
        self.controller.train()
        self.ctrl_optim.zero_grad()
        rewards = []
        for ctrl_step in range(self.ctrl_steps_aggregate):
            self._resample()
            with torch.no_grad():
                metric, loss = self._infer(mask="val")
            reward = metric
            rewards.append(reward)
            if self.entropy_weight:
                reward += self.entropy_weight * self.controller.sample_entropy.item()
            self.baseline = self.baseline * self.baseline_decay + reward * (
                1 - self.baseline_decay
            )
            loss = self.controller.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.controller.sample_skip_penalty
            loss /= self.ctrl_steps_aggregate
            loss.backward()

            if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
                if self.grad_clip > 0:
                    nn.utils.clip_grad_norm_(
                        self.controller.parameters(), self.grad_clip
                    )
                self.ctrl_optim.step()
                self.ctrl_optim.zero_grad()

            if self.log_frequency is not None and ctrl_step % self.log_frequency == 0:
                LOGGER.info(
                    "RL Epoch [%d/%d] Step [%d/%d]",
                    epoch + 1,
                    self.num_epochs,
                    ctrl_step + 1,
                    self.ctrl_steps_aggregate,
                )
        return sum(rewards) / len(rewards)

    def _resample(self):
        result = self.controller.resample()
        for name, module in self.nas_modules:
            module.sampled = result[name]

    def export(self):
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()

    def _infer(self, mask="train"):
        metric, loss = self.estimator.infer(self.model, self.dataset, mask=mask)
        return metric[0], loss
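A minimal usage sketch of this trainer, assuming a search space, dataset, and estimator have been constructed elsewhere: the names ``my_space``, ``my_dataset``, and ``my_estimator`` below are hypothetical placeholders standing in for a ``BaseSpace`` subclass and an estimator exposing the ``infer(model, dataset, mask=...)`` interface used by ``_infer`` above; only ``Enas`` and its ``search`` method come from this module.

# Hypothetical sketch, not part of the module:
# my_space, my_dataset, and my_estimator must be supplied by the caller.
algo = Enas(
    num_epochs=5,            # controller/supernet alternation epochs
    n_warmup=100,            # supernet warm-up epochs before RL training
    ctrl_steps_aggregate=20, # controller steps aggregated per mini-batch
    device="auto",
)
best_model = algo.search(my_space, my_dataset, my_estimator)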