# Code in this file is adapted from https://github.com/microsoft/nni with some changes.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import BaseNAS
from ..space import (
BaseSpace,
replace_layer_choice,
replace_input_choice,
get_module_order,
sort_replaced_module,
PathSamplingInputChoice,
PathSamplingLayerChoice,
apply_fixed_architecture,
)
from tqdm import tqdm
import numpy as np
import logging
LOGGER = logging.getLogger(__name__)
from .ea_utils import Individual, UniformSampler, MutationSampler
import collections
import random
class Evolution:
    """
    Algorithm for regularized evolution (i.e. aging evolution).

    Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Classifier Architecture Search".

    Parameters
    ----------
    optimize_mode : str
        Can be one of "maximize" and "minimize". Default: maximize.
    population_size : int
        The number of individuals to keep in the population. Default: 100.
    sample_size : int
        The number of individuals that should participate in each tournament. Default: 25.
    cycles : int
        The number of cycles (trials) the algorithm should run for. Default: 20000.
    mutation_prob : float
        Probability that mutation happens in each dim. Default: 0.05
    disable_progress : bool
        If True, the tqdm progress bars are hidden. Default: False.
    """

    def __init__(
        self,
        optimize_mode="maximize",
        population_size=100,
        sample_size=25,
        cycles=20000,
        mutation_prob=0.05,
        disable_progress=False,
    ):
        assert optimize_mode in ["maximize", "minimize"]
        assert sample_size < population_size
        self.optimize_mode = optimize_mode
        self.population_size = population_size
        self.sample_size = sample_size
        self.cycles = cycles
        self.mutation_prob = mutation_prob
        self.disable_progress = disable_progress
        self._worst = (
            float("-inf") if self.optimize_mode == "maximize" else float("inf")
        )
        self._success_count = 0
        # Current population, oldest individual at the left end so that aging
        # (discarding the oldest member) is an O(1) popleft.
        self._population = collections.deque()
        self._running_models = []
        self._polling_interval = 2.0
        # Every individual ever evaluated; the final answer is chosen from here.
        self._history = []

    def best_parent(self, sample_size=None):
        """Return the config (``x``) of the winner of a tournament.

        A shuffled copy of the population is taken. If ``sample_size`` is given,
        only its first ``sample_size`` entries compete. The winner is the
        individual with the largest ``y`` when maximizing, the smallest otherwise.
        """
        # Copy so that shuffling does not reorder the (age-ordered) population.
        samples = list(self._population)
        random.shuffle(samples)
        if sample_size is not None:
            samples = samples[:sample_size]
        pick = max if self.optimize_mode == "maximize" else min
        return pick(samples, key=lambda sample: sample.y).x

    def _prepare(self):
        """Create the uniform sampler (initial population) and the mutation sampler."""
        self.uniform = UniformSampler(self.nas_modules)
        self.mutation = MutationSampler(self.nas_modules, self.mutation_prob)

    def _get_metric(self, config):
        """Score ``config`` on the validation split using the shared supernet weights.

        Each NAS module's ``sampled`` attribute is set from ``config[name]``. The
        estimator is then run with ``mask="val"`` and gradients disabled, and
        ``metric["acc"]`` is returned.
        """
        for name, module in self.nas_modules:
            module.sampled = config[name]
        # Candidates must be scored in evaluation mode. Otherwise layers such as
        # dropout (if the space contains any) make the fitness of the same
        # architecture random, which corrupts the tournament.
        self.model.eval()
        # TODO: parsing a standalone model per config (self.model.parse_model)
        # may be more faithful but is computationally expensive.
        with torch.no_grad():
            metric, _ = self.estimator.infer(self.model, self.dataset, mask="val")
        return metric["acc"]

    def search(self, space: "BaseSpace", nas_modules, dset, estimator, device):
        """Run regularized evolution over the choices in ``nas_modules``.

        Parameters
        ----------
        space : BaseSpace
            Supernet whose shared weights are used to score every candidate.
        nas_modules : list of (str, module) pairs
            Named choice modules; each is configured through its ``sampled`` attribute.
        dset
            Dataset forwarded to ``estimator.infer``.
        estimator
            Object whose ``infer(model, dataset, mask=...)`` returns ``(metric, loss)``.
        device
            Stored on ``self.device``; not otherwise used by this class.

        Returns
        -------
        The config (``x``) of the best individual ever evaluated.
        """
        self.model = space
        self.dataset = dset
        self.estimator = estimator
        self.nas_modules = nas_modules
        self.device = device
        self._prepare()
        maximize = self.optimize_mode == "maximize"

        LOGGER.info("Initializing the first population.")
        with tqdm(range(self.population_size), disable=self.disable_progress) as bar:
            for _ in bar:
                config = self.uniform.resample()
                metric = self._get_metric(config)
                individual = Individual(config, metric)
                self._population.append(individual)
                self._history.append(individual)
                bar.set_postfix(
                    metric=metric,
                    max=max(x.y for x in self._population),
                    min=min(x.y for x in self._population),
                )

        # Running maximum over the whole history. It is tracked incrementally
        # rather than rescanning the ever-growing history on every cycle, which
        # would be quadratic in ``cycles``.
        best_seen = max((x.y for x in self._history), default=float("-inf"))

        LOGGER.info("Running mutations.")
        with tqdm(range(self.cycles), disable=self.disable_progress) as bar:
            for _ in bar:
                parent = self.best_parent(self.sample_size)
                config = self.mutation.resample(parent)
                metric = self._get_metric(config)
                individual = Individual(config, metric)
                LOGGER.debug("Individual created: %s", str(individual))
                self._population.append(individual)
                self._history.append(individual)
                # Aging: once the population is over capacity, the oldest member
                # dies regardless of its fitness.
                if len(self._population) > self.population_size:
                    self._population.popleft()
                best_seen = max(best_seen, metric)
                bar.set_postfix(
                    metric=metric,
                    max_h=best_seen,
                    max=max(x.y for x in self._population),
                    min=min(x.y for x in self._population),
                )

        # TODO: the best individual in history is returned. Alternatively, the
        # final population could be retrained and re-ranked.
        self._history.sort(key=lambda x: x.y)
        return self._history[-1].x if maximize else self._history[0].x
class Spos(BaseNAS):
    """
    SPOS trainer.

    The supernet is first trained for ``n_warmup`` steps, each on one path drawn
    uniformly at random. Regularized evolution (:class:`Evolution`) then searches
    over paths, scoring them with the shared weights.

    Parameters
    ----------
    n_warmup : int
        Number of epochs (single-path optimizer steps) for training super network.
    grad_clip : float
        Max gradient norm for clipping; values ``<= 0`` disable clipping.
    disable_progress : bool
        If True, the tqdm progress bars are hidden.
    optimize_mode, population_size, sample_size, cycles, mutation_prob
        Forwarded to :class:`Evolution`; see there.
    device : str or torch.device
        Passed to ``BaseNAS``; the super network is moved to ``self.device``.
    model_lr : float
        Learning rate for super network. Default: 5e-3.
    model_wd : float
        Weight decay for super network. Default: 5e-4.
    """

    def __init__(
        self,
        n_warmup=1000,
        grad_clip=5.0,
        disable_progress=False,
        optimize_mode="maximize",
        population_size=100,
        sample_size=25,
        cycles=20000,
        mutation_prob=0.05,
        device="cuda",
        model_lr=5e-3,
        model_wd=5e-4,
    ):
        super().__init__(device)
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.n_warmup = n_warmup
        self.disable_progress = disable_progress
        self.grad_clip = grad_clip
        self.optimize_mode = optimize_mode
        self.population_size = population_size
        self.sample_size = sample_size
        self.cycles = cycles
        self.mutation_prob = mutation_prob

    def _prepare(self):
        """Turn the space into a path-sampling supernet and build optimizer, sampler and evolver."""
        # Replace layer/input choices with path-sampling modules, keeping the
        # module order recorded before replacement.
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
        self.model = self.model.to(self.device)
        self.model_optim = torch.optim.Adam(
            self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
        )
        # Uniform path sampler used during supernet warm-up.
        self.controller = UniformSampler(self.nas_modules)
        self.evolve = Evolution(
            optimize_mode=self.optimize_mode,  # honour the user's setting
            population_size=self.population_size,
            sample_size=self.sample_size,
            cycles=self.cycles,
            mutation_prob=self.mutation_prob,
            disable_progress=self.disable_progress,
        )

    def search(self, space: BaseSpace, dset, estimator):
        """Train the supernet, run the evolutionary search and build the chosen model.

        Returns the result of ``space.parse_model`` for the selected architecture.
        """
        self.model = space
        self.dataset = dset
        self.estimator = estimator
        self._prepare()
        self._train()  # train the supernet with uniform path sampling
        self._search()  # search with the evolutionary algorithm
        selection = self.export()
        # TODO: could sample N candidates, retrain them and keep the best.
        LOGGER.info("Selected architecture: %s", selection)
        return space.parse_model(selection)

    def _search(self):
        """Run :class:`Evolution` on the trained supernet and store its best config."""
        self.best_config = self.evolve.search(
            self.model, self.nas_modules, self.dataset, self.estimator, self.device,
        )

    def _train(self):
        """Warm up the supernet for ``n_warmup`` steps, reporting train/val metrics."""
        with tqdm(range(self.n_warmup), disable=self.disable_progress) as bar:
            for i in bar:
                acc, l1 = self._train_one_epoch(i)
                # Validate in evaluation mode; _train_one_epoch switches the model
                # back to training mode at the start of the next step.
                self.model.eval()
                with torch.no_grad():
                    val_acc, val_loss = self._infer("val")
                bar.set_postfix(
                    loss=l1.item(),
                    acc=acc["acc"],
                    val_acc=val_acc["acc"],
                    val_loss=val_loss.item(),
                )

    def _train_one_epoch(self, epoch):
        """Take one optimizer step on a freshly sampled path; return ``(metric, loss)`` on train."""
        self.model.train()
        self.model_optim.zero_grad()
        self._resample()  # uniform sampling
        metric, loss = self._infer(mask="train")
        loss.backward()
        if self.grad_clip > 0:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.model_optim.step()
        return metric, loss

    def _resample(self):
        """Draw a path uniformly and apply it to every NAS module's ``sampled`` attribute."""
        result = self.controller.resample()
        for name, module in self.nas_modules:
            module.sampled = result[name]

    def export(self):
        """Return the best config found by the evolutionary search."""
        return self.best_config

    def _infer(self, mask="train"):
        """Evaluate the current supernet on ``mask``; returns the estimator's ``(metric, loss)``."""
        metric, loss = self.estimator.infer(self.model, self.dataset, mask=mask)
        return metric, loss