Source code for autogllight.nas.algorithm.graphnas_rl

import torch
import torch.nn as nn
import logging

from .base import BaseNAS
from ..space import (
    BaseSpace,
    replace_layer_choice,
    replace_input_choice,
    get_module_order,
    sort_replaced_module,
    PathSamplingInputChoice,
    PathSamplingLayerChoice,
    apply_fixed_architecture,
)
from tqdm import tqdm
from datetime import datetime
import numpy as np
from .rl_utils import ReinforceController, ReinforceField

LOGGER = logging.getLogger(__name__)


class GraphNasRL(BaseNAS):
    """
    RL search from GraphNAS.

    Parameters
    ----------
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    num_epochs : int
        Number of epochs planned for training.
    log_frequency : int
        Number of steps between logging.
    grad_clip : float
        Gradient clipping. Set to 0 to disable. Default: 5.
    entropy_weight : float
        Weight of the sample entropy loss.
    skip_weight : float
        Weight of the skip penalty loss.
    baseline_decay : float
        Decay factor of the baseline. The new baseline will be equal to
        ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_lr : float
        Learning rate for the RL controller.
    ctrl_steps_aggregate : int
        Number of steps aggregated into one mini-batch for the RL controller.
    ctrl_steps : int
        Number of mini-batches per epoch of RL controller learning.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
    n_warmup : int
        Number of epochs for training the super network.
    model_lr : float
        Learning rate for the super network.
    model_wd : float
        Weight decay for the super network.
    topk : int
        Number of architectures kept during the training process.
    disable_progress : bool
        Whether to disable the progress bar.
    hardware_metric_limit : float, optional
        Limit on the hardware metric.
    weight_share : bool
        Whether to evaluate sampled architectures with the shared super-network
        weights; otherwise each sampled architecture is instantiated via
        ``parse_model``.
    """

    def __init__(
        self,
        device="auto",
        num_epochs=10,
        log_frequency=None,
        grad_clip=5.0,
        entropy_weight=0.0001,
        skip_weight=0,
        baseline_decay=0.95,
        ctrl_lr=0.00035,
        ctrl_steps_aggregate=100,
        ctrl_kwargs=None,
        n_warmup=100,
        model_lr=5e-3,
        model_wd=5e-4,
        topk=5,
        disable_progress=False,
        hardware_metric_limit=None,
        weight_share=True,
    ):
        super().__init__(device)
        self.num_epochs = num_epochs
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip
        self.ctrl_kwargs = ctrl_kwargs
        self.ctrl_lr = ctrl_lr
        self.n_warmup = n_warmup
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.hist = []
        self.topk = topk
        self.disable_progress = disable_progress
        self.hardware_metric_limit = hardware_metric_limit
        self.weight_share = weight_share
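    # Illustrative note (not part of the original source): with the default
    # baseline_decay of 0.95, the reward baseline is an exponential moving
    # average of past rewards. For example, an old baseline of 0.60 and a new
    # reward of 0.80 give 0.95 * 0.60 + 0.80 * (1 - 0.95) = 0.61, so the
    # REINFORCE advantage used for that sample is reward - baseline = 0.19.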
    def search(self, space: BaseSpace, dset, estimator):
        self.model = space
        self.dataset = dset  # .to(self.device)
        self.estimator = estimator
        # replace choice
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
        # to device
        self.model = self.model.to(self.device)
        # fields
        self.nas_fields = [
            ReinforceField(
                name,
                len(module),
                isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
            )
            for name, module in self.nas_modules
        ]
        self.controller = ReinforceController(
            self.nas_fields,
            lstm_size=100,
            temperature=5.0,
            tanh_constant=2.5,
            **(self.ctrl_kwargs or {}),
        )
        self.ctrl_optim = torch.optim.Adam(
            self.controller.parameters(), lr=self.ctrl_lr
        )
        # train
        with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
            for i in bar:
                l2 = self._train_controller(i)
                bar.set_postfix(reward_controller=l2)

        # selection = self.export()
        selections = [x[1] for x in self.hist]
        candidate_accs = [-x[0] for x in self.hist]
        # print('candidate accuracies', candidate_accs)
        selection = self._choose_best(selections)
        arch = space.parse_model(selection)
        # print(selection, arch)
        return arch
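    # Illustrative note (assumption, not stated in the original source): a
    # sampled selection is expected to be a dict keyed by choice-module name,
    # mapping each ReinforceField to the index (or mask) the controller
    # sampled, e.g. {"op_0": 2, "op_1": 0}; ``space.parse_model(selection)``
    # then materializes the corresponding concrete architecture.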
    def _choose_best(self, selections):
        # GraphNAS keeps the top 5 models, evaluates each 20 times, and chooses the best.
        results = []
        for selection in selections:
            accs = []
            for i in tqdm(range(20), disable=self.disable_progress):
                self.arch = self.model.parse_model(selection)
                metric, loss = self._infer(mask="val")
                metric = metric["acc"]
                accs.append(metric)
            result = np.mean(accs)
            LOGGER.info(
                "selection {} \n acc {:.4f} +- {:.4f}".format(
                    selection, np.mean(accs), np.std(accs) / np.sqrt(20)
                )
            )
            results.append(result)
        best_selection = selections[np.argmax(results)]
        return best_selection

    def _train_controller(self, epoch):
        self.model.eval()
        self.controller.train()
        self.ctrl_optim.zero_grad()
        rewards = []
        baseline = None
        # diff: GraphNAS trains 100 and derives 100 for every epoch (10 epochs);
        # we just train 100 (20 epochs). The total number of samples is the same (2000).
        with tqdm(
            range(self.ctrl_steps_aggregate), disable=self.disable_progress
        ) as bar:
            for ctrl_step in bar:
                self._resample()
                metric, loss = self._infer(mask="val")
                reward = metric["acc"]
                # bar.set_postfix(acc=metric, loss=loss.item())
                LOGGER.debug(f"{self.selection}\n{metric},{loss}")

                # diff: no reward shaping as in the GraphNAS code
                # hist stores (-reward, selection); sorting ascending and popping
                # the last entry drops the lowest-reward selection, keeping the top-k
                self.hist.append([-reward, self.selection])
                if len(self.hist) > self.topk:
                    self.hist.sort(key=lambda x: x[0])
                    self.hist.pop()

                rewards.append(reward)
                if self.entropy_weight:
                    reward += (
                        self.entropy_weight * self.controller.sample_entropy.item()
                    )

                if not baseline:
                    baseline = reward
                else:
                    baseline = baseline * self.baseline_decay + reward * (
                        1 - self.baseline_decay
                    )

                loss = self.controller.sample_log_prob * (reward - baseline)
                self.ctrl_optim.zero_grad()
                loss.backward()
                self.ctrl_optim.step()

                bar.set_postfix(max_acc=max(rewards), **metric)
        LOGGER.info(
            "epoch:{}, mean rewards:{}".format(epoch, sum(rewards) / len(rewards))
        )
        return sum(rewards) / len(rewards)

    def _resample(self):
        result = self.controller.resample()
        if self.weight_share:
            for name, module in self.nas_modules:
                module.sampled = result[name]
        else:
            self.arch = self.model.parse_model(result)
        self.selection = result

    def export(self):
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()

    def _infer(self, mask="train"):
        if self.weight_share:
            metric, loss = self.estimator.infer(self.model, self.dataset, mask=mask)
        else:
            metric, loss = self.estimator.infer(
                self.arch._model, self.dataset, mask=mask
            )
        return metric, loss
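A minimal usage sketch follows (not part of the module source). It assumes a search space ``space`` (a ``BaseSpace`` subclass), a ``dataset``, and an ``estimator`` exposing the ``infer(model, dataset, mask=...)`` interface used above, all constructed elsewhere; only the API defined in this module is exercised.

# Hypothetical usage sketch; `space`, `dataset`, and `estimator` are placeholders
# assumed to be built with the interfaces this module relies on.
algo = GraphNasRL(
    num_epochs=10,             # controller training epochs
    ctrl_steps_aggregate=100,  # architectures sampled per epoch
    topk=5,                    # candidates kept for the final 20-run re-evaluation
    disable_progress=False,
)
best_arch = algo.search(space, dataset, estimator)  # returns space.parse_model(best_selection)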