# Code in this file is reproduced from https://github.com/microsoft/nni with some changes.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import register_nas_algo
from .base import BaseNAS
from ..space import BaseSpace
from ..utils import (
AverageMeterGroup,
replace_layer_choice,
replace_input_choice,
get_module_order,
sort_replaced_module,
PathSamplingLayerChoice,
PathSamplingInputChoice,
)
from tqdm import tqdm
from ....utils import get_logger
import numpy as np
import collections
import dataclasses
import random

LOGGER = get_logger("SPOS")

@dataclasses.dataclass
class Individual:
"""
A class that represents an individual.
Holds two attributes, where ``x`` is the model and ``y`` is the metric (e.g., accuracy).
"""
x: dict
y: float
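
# Illustrative example (hypothetical values): an ``Individual`` pairs an architecture
# choice dict with its metric, e.g. Individual(x={"op_0": 1, "op_1": 3}, y=0.81).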
class Evolution:
"""
Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Classifier Architecture Search".
Parameters
----------
optimize_mode : str
Can be one of "maximize" and "minimize". Default: maximize.
population_size : int
The number of individuals to keep in the population. Default: 100.
cycles : int
The number of cycles (trials) the algorithm should run for. Default: 20000.
sample_size : int
The number of individuals that should participate in each tournament. Default: 25.
    mutation_prob : float
        Probability that mutation happens in each dimension. Default: 0.05.
    disable_progress : bool
        Whether to disable the tqdm progress bars. Default: False.
    """
    def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25,
                 cycles=20000, mutation_prob=0.05, disable_progress=False):
assert optimize_mode in ['maximize', 'minimize']
assert sample_size < population_size
self.optimize_mode = optimize_mode
self.population_size = population_size
self.sample_size = sample_size
self.cycles = cycles
self.mutation_prob = mutation_prob
        self.disable_progress = disable_progress
self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf')
self._success_count = 0
self._population = collections.deque()
self._running_models = []
self._polling_interval = 2.
self._history = []
    def best_parent(self, sample_size=None):
        """Run one tournament: sample up to ``sample_size`` individuals at random
        and return the config of the best one."""
        samples = [p for p in self._population]  # copy the population
random.shuffle(samples)
if sample_size is not None:
samples = list(samples)[:sample_size]
if self.optimize_mode == 'maximize':
parent = max(samples, key=lambda sample: sample.y)
else:
parent = min(samples, key=lambda sample: sample.y)
return parent.x
    def _prepare(self):
        self.uniform = UniformSampler(self.nas_modules)
        self.mutation = MutationSampler(self.nas_modules, self.mutation_prob)
    def _get_metric(self, config):
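        """Evaluate ``config`` on the validation split: activate the given choices in
        the supernet's choice modules, then run the estimator without gradients."""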
for name, module in self.nas_modules:
module.sampled = config[name]
        # TODO: this may be computationally expensive
        # model = self.model.parse_model(config, self.device)
with torch.no_grad():
metric, loss = self.estimator.infer(self.model, self.dataset, mask='val')
return metric[0]
    def search(self, space: BaseSpace, nas_modules, dset, estimator, device):
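        """Run regularized evolution on top of a trained supernet: initialize the
        population with uniform samples, then repeatedly mutate tournament winners.
        Returns the best configuration found over the whole history."""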
self.model = space
self.dataset = dset
self.estimator = estimator
self.nas_modules = nas_modules
self.device = device
self._prepare()
LOGGER.info('Initializing the first population.')
with tqdm(range(self.population_size), disable=self.disable_progress) as bar:
for i in bar:
                config = self.uniform.resample()
                metric = self._get_metric(config)
                individual = Individual(config, metric)
                # LOGGER.debug('Individual created: %s', str(individual))
                self._population.append(individual)
                self._history.append(individual)
                bar.set_postfix(
                    metric=metric,
                    max=max(x.y for x in self._population),
                    min=min(x.y for x in self._population),
                )
LOGGER.info('Running mutations.')
with tqdm(range(self.cycles), disable=self.disable_progress) as bar:
for i in bar:
                parent = self.best_parent(self.sample_size)
                config = self.mutation.resample(parent)
                metric = self._get_metric(config)  # TODO: add an aging factor
individual = Individual(config, metric)
LOGGER.debug('Individual created: %s', str(individual))
self._population.append(individual)
self._history.append(individual)
if len(self._population) > self.population_size:
self._population.popleft()
                bar.set_postfix(
                    metric=metric,
                    max_h=max(x.y for x in self._history),
                    max=max(x.y for x in self._population),
                    min=min(x.y for x in self._population),
                )
        # TODO: the original algorithm returns the best individual in history;
        # alternatively, the final population could be retrained and re-evaluated.
self._history.sort(key=lambda x: x.y)
# best=self.best_parent()
if self.optimize_mode == 'maximize':
best=self._history[-1].x
else:
best=self._history[0].x
return best
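
# Usage sketch (illustrative, not executed): ``Evolution`` is driven by ``Spos._search``
# below, roughly as
#
#   evolve = Evolution(optimize_mode="maximize", population_size=100,
#                      sample_size=25, cycles=20000, mutation_prob=0.05)
#   best_config = evolve.search(space, nas_modules, dataset, estimator, device)
#
# where ``nas_modules`` is the list of (name, choice module) pairs produced by
# ``replace_layer_choice`` / ``replace_input_choice``.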
class MutationSampler:
    """Uniform mutation sampler.

    ``resample(parent)`` takes the parent individual's choice dict and, with
    probability ``mutation_prob``, resamples each choice uniformly at random
    (the original operator is not excluded).

    Parameters
    ----------
    nas_modules : list
        ``(name, module)`` pairs of choice modules used by the NAS algorithm.
    mutation_prob : float
        Probability of mutating each choice independently.
    """

    def __init__(self, nas_modules, mutation_prob):
        selection_range = {}
        for k, v in nas_modules:
            selection_range[k] = len(v)
        self.selection_dict = selection_range
        self.mutation_prob = mutation_prob
    def resample(self, parent):
        search_space = self.selection_dict
        child = {}
        for k, v in parent.items():
            if random.uniform(0, 1) < self.mutation_prob:
                # sample uniformly over all operators; the original operator is not excluded
                child[k] = np.random.choice(range(search_space[k]))
else:
child[k] = v
return child
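
# Minimal sketch (illustrative, not executed; the names are hypothetical): mutate a
# parent configuration over a toy search space with two choice points of sizes 3 and 2.
#
#   toy_modules = [("op_0", [None] * 3), ("op_1", [None] * 2)]
#   sampler = MutationSampler(toy_modules, mutation_prob=0.5)
#   child = sampler.resample({"op_0": 0, "op_1": 1})  # e.g. {"op_0": 2, "op_1": 1}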
class UniformSampler:
    """Uniform sampler: draws each choice uniformly at random.

    Parameters
    ----------
    nas_modules : list
        ``(name, module)`` pairs of choice modules used by the NAS algorithm.
    """

    def __init__(self, nas_modules):
        selection_range = {}
        for k, v in nas_modules:
            selection_range[k] = len(v)
        self.selection_dict = selection_range
def resample(self):
selection = {}
for k, v in self.selection_dict.items():
selection[k] = np.random.choice(range(v))
return selection
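
# Minimal sketch (illustrative, not executed): draw one uniformly random configuration
# from the same toy search space as above.
#
#   sampler = UniformSampler([("op_0", [None] * 3), ("op_1", [None] * 2)])
#   config = sampler.resample()  # e.g. {"op_0": 1, "op_1": 0}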
@register_nas_algo("spos")
class Spos(BaseNAS):
"""
SPOS trainer.
Parameters
----------
n_warmup : int
Number of epochs for training super network.
model_lr : float
Learning rate for super network.
model_wd : float
Weight decay for super network.
Other parameters see Evolution
"""
    def __init__(
        self,
        n_warmup=1000,
        model_lr=5e-3,
        model_wd=5e-4,
        grad_clip=5.0,
        disable_progress=False,
        optimize_mode='maximize',
        population_size=100,
        sample_size=25,
        cycles=20000,
        mutation_prob=0.05,
        device="cuda",
    ):
        super().__init__(device)
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.n_warmup = n_warmup
        self.disable_progress = disable_progress
self.grad_clip = grad_clip
self.optimize_mode = optimize_mode
self.population_size = population_size
self.sample_size = sample_size
self.cycles = cycles
self.mutation_prob = mutation_prob
def _prepare(self):
        # replace layer/input choices with path-sampling modules and collect them in order
self.nas_modules = []
k2o = get_module_order(self.model)
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
# to device
self.model = self.model.to(self.device)
self.model_optim = torch.optim.Adam(
self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
)
        # controller for uniform architecture sampling during supernet training
        self.controller = UniformSampler(self.nas_modules)
        # evolutionary search controller
        self.evolve = Evolution(
            optimize_mode=self.optimize_mode,
            population_size=self.population_size,
            sample_size=self.sample_size,
            cycles=self.cycles,
            mutation_prob=self.mutation_prob,
            disable_progress=self.disable_progress,
        )
    def search(self, space: BaseSpace, dset, estimator):
self.model = space
self.dataset = dset
self.estimator = estimator
self._prepare()
self._train() # train using uniform sampling
self._search() # search using evolutionary algorithm
        selection = self.export()
        # TODO: optionally sample N candidates here, retrain each, and keep the best
        LOGGER.info("Selected architecture: %s", selection)
        return space.parse_model(selection, self.device)
    def _search(self):
        self.best_config = self.evolve.search(
self.model,
self.nas_modules,
self.dataset,
self.estimator,
self.device,
)
def _train(self):
with tqdm(range(self.n_warmup), disable=self.disable_progress) as bar:
for i in bar:
acc, l1 = self._train_one_epoch(i)
with torch.no_grad():
val_acc, val_loss = self._infer("val")
bar.set_postfix(loss=l1, acc=acc, val_acc=val_acc, val_loss=val_loss.item())
def _train_one_epoch(self, epoch):
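        """Single-path training step: sample one architecture uniformly, compute the
        loss on the training split, and update the supernet weights."""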
self.model.train()
self.model_optim.zero_grad()
self._resample() # uniform sampling
metric, loss = self._infer(mask="train")
loss.backward()
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
self.model_optim.step()
return metric, loss.item()
def _resample(self):
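        """Sample one architecture uniformly and activate it in the supernet."""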
        result = self.controller.resample()
for name, module in self.nas_modules:
module.sampled = result[name]
def export(self):
return self.best_config
def _infer(self, mask="train"):
metric, loss = self.estimator.infer(self.model, self.dataset, mask=mask)
return metric[0], loss
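
# End-to-end usage sketch (illustrative; ``space``, ``dataset`` and ``estimator`` are
# assumed to come from the surrounding NAS pipeline: a ``BaseSpace`` subclass, a
# dataset, and an estimator exposing ``infer``):
#
#   algo = Spos(n_warmup=300, population_size=50, sample_size=10, cycles=1000,
#               device="cuda")
#   best_model = algo.search(space, dataset, estimator)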