# Modified from NNI
import logging
import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from . import register_nas_algo
from .base import BaseNAS
from ..estimator.base import BaseEstimator
from ..space import BaseSpace
from ..utils import replace_layer_choice, replace_input_choice
# from nni.retiarii.oneshot.pytorch.darts import DartsLayerChoice, DartsInputChoice
# Logger named after this module (the module's dotted import path).
_logger = logging.getLogger(name=__name__)
# Copied from NNI 2.1 for stability.
class DartsLayerChoice(nn.Module):
    """Continuous (softmax-weighted) relaxation of a layer choice, as used by DARTS.

    Every candidate op is run on the same inputs, and the outputs are mixed
    with ``softmax(alpha)`` weights, where ``alpha`` holds one learnable logit
    per candidate.

    Calling ``parameters()`` / ``named_parameters()`` on this module directly
    skips ``alpha``, so the architecture logits are not handed to a
    weight optimizer built from this module. A *parent* module's
    ``parameters()`` does not go through these overrides and still yields
    ``alpha``.
    """

    def __init__(self, layer_choice):
        super(DartsLayerChoice, self).__init__()
        # Key of the originating choice; Darts.search uses it to share one
        # alpha between modules carrying the same label.
        self.name = layer_choice.key
        self.op_choices = nn.ModuleDict(layer_choice.named_children())
        # Tiny random init keeps the initial mixture close to uniform.
        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

    def forward(self, *args, **kwargs):
        """Return the ``softmax(alpha)``-weighted sum of all candidate outputs.

        All candidates must return tensors of the same shape, because their
        outputs are stacked along a new leading dimension.
        """
        op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
        # View the weights as (n_ops, 1, ..., 1) so they broadcast over each output.
        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self, recurse=True):
        """Yield this module's parameters, excluding the architecture logits."""
        for _, p in self.named_parameters(recurse=recurse):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Yield ``(name, parameter)`` pairs, excluding the architecture logits.

        Accepts the same arguments as ``torch.nn.Module.named_parameters``
        (``prefix``, ``recurse``, ...) and forwards them unchanged.
        """
        for name, p in super(DartsLayerChoice, self).named_parameters(*args, **kwargs):
            # Filter by identity rather than by name, so the filter still
            # works when a prefix turns the name into "<prefix>.alpha".
            if p is self.alpha:
                continue
            yield name, p

    def export(self):
        """Return the index of the candidate with the largest logit."""
        return torch.argmax(self.alpha).item()
class DartsInputChoice(nn.Module):
    """Continuous (softmax-weighted) relaxation of an input choice, as used by DARTS.

    All candidate inputs are mixed with ``softmax(alpha)`` weights, where
    ``alpha`` holds one learnable logit per candidate.

    Calling ``parameters()`` / ``named_parameters()`` on this module directly
    skips ``alpha``. A *parent* module's ``parameters()`` does not go through
    these overrides and still yields it.
    """

    def __init__(self, input_choice):
        super(DartsInputChoice, self).__init__()
        # Key of the originating choice; Darts.search uses it to share one
        # alpha between modules carrying the same label.
        self.name = input_choice.key
        # Tiny random init keeps the initial mixture close to uniform.
        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
        # A falsy n_chosen (e.g. None) falls back to selecting one input on export.
        self.n_chosen = input_choice.n_chosen or 1

    def forward(self, inputs):
        """Return the ``softmax(alpha)``-weighted sum of ``inputs``.

        ``inputs`` is a sequence of same-shaped tensors, one per candidate.
        It is stacked along a new leading dimension.
        """
        inputs = torch.stack(inputs)
        # View the weights as (n_candidates, 1, ..., 1) so they broadcast.
        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self, recurse=True):
        """Yield this module's parameters, excluding the architecture logits."""
        for _, p in self.named_parameters(recurse=recurse):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Yield ``(name, parameter)`` pairs, excluding the architecture logits.

        Accepts the same arguments as ``torch.nn.Module.named_parameters``
        (``prefix``, ``recurse``, ...) and forwards them unchanged.
        """
        for name, p in super(DartsInputChoice, self).named_parameters(*args, **kwargs):
            # Filter by identity so a prefix ("<prefix>.alpha") does not defeat it.
            if p is self.alpha:
                continue
            yield name, p

    def export(self):
        """Return the indices of the ``n_chosen`` largest logits, best first."""
        # argsort yields an integer tensor; tolist() gives plain Python ints
        # directly, without a NumPy round-trip.
        return torch.argsort(-self.alpha)[: self.n_chosen].tolist()
@register_nas_algo("darts")
class Darts(BaseNAS):
    """
    DARTS trainer.

    Parameters
    ----------
    num_epochs : int
        Number of epochs planned for training. Default: 5.
    workers : int
        Workers for data loading. Stored as ``self.workers``; this class does
        not read it itself. Default: 4.
    gradient_clip : float
        Maximum gradient norm used to clip the model-weight gradients.
        Set to 0 to disable. Default: 5.
    model_lr : float
        Learning rate to optimize the model. Default: 1e-3.
    model_wd : float
        Weight decay to optimize the model. Default: 5e-4.
    arch_lr : float
        Learning rate to optimize the architecture. Default: 3e-4.
    arch_wd : float
        Weight decay to optimize the architecture. Default: 1e-3.
    device : str or torch.device
        The device of the whole process. Default: "auto".
    """

    def __init__(
        self,
        num_epochs=5,
        workers=4,
        gradient_clip=5.0,
        model_lr=1e-3,
        model_wd=5e-4,
        arch_lr=3e-4,
        arch_wd=1e-3,
        device="auto",
    ):
        super().__init__(device=device)
        self.num_epochs = num_epochs
        self.workers = workers
        self.gradient_clip = gradient_clip
        # Both the weight and the architecture optimizers are Adam.
        self.model_optimizer = torch.optim.Adam
        self.arch_optimizer = torch.optim.Adam
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.arch_lr = arch_lr
        self.arch_wd = arch_wd

    def search(self, space: BaseSpace, dataset, estimator):
        """Run DARTS on ``space`` and return the derived discrete model.

        Parameters
        ----------
        space : BaseSpace
            Search space. Its layer/input choices are swapped for their DARTS
            relaxations by ``replace_layer_choice`` / ``replace_input_choice``
            (presumably in place, since the same ``space`` is trained afterwards).
        dataset
            Passed through unchanged to ``estimator.infer``.
        estimator : BaseEstimator
            Provides ``(metric, loss)`` for the ``"train"`` and ``"val"`` masks.

        Returns
        -------
        The result of ``space.parse_model(selection, self.device)``, where
        ``selection`` maps each choice name to its exported choice.
        """
        # Build the weight optimizer BEFORE the choices are replaced. The DARTS
        # modules and their alpha logits are only created by replace_*_choice
        # below, so the logits are not optimized as weights.
        model_optim = self.model_optimizer(
            space.parameters(), self.model_lr, weight_decay=self.model_wd
        )
        nas_modules = []
        replace_layer_choice(space, DartsLayerChoice, nas_modules)
        replace_input_choice(space, DartsInputChoice, nas_modules)
        space = space.to(self.device)

        # Modules that carry the same label share a single set of logits.
        ctrl_params = {}
        for _, m in nas_modules:
            if m.name in ctrl_params:
                assert (
                    m.alpha.size() == ctrl_params[m.name].size()
                ), "Size of parameters with the same label should be same."
                m.alpha = ctrl_params[m.name]
            else:
                ctrl_params[m.name] = m.alpha
        arch_optim = self.arch_optimizer(
            list(ctrl_params.values()), self.arch_lr, weight_decay=self.arch_wd
        )

        for epoch in range(self.num_epochs):
            self._train_one_epoch(
                epoch, space, dataset, estimator, model_optim, arch_optim
            )

        selection = self.export(nas_modules)
        return space.parse_model(selection, self.device)

    def _train_one_epoch(
        self,
        epoch,
        model: BaseSpace,
        dataset,
        estimator,
        model_optim: torch.optim.Optimizer,
        arch_optim: torch.optim.Optimizer,
    ):
        """Run one alternating DARTS step: an architecture update, then a weight update.

        ``epoch`` is accepted for interface compatibility and is not used.
        """
        model.train()

        # Phase 1: architecture step on the validation loss (first-order
        # DARTS: no unrolled weight update).
        arch_optim.zero_grad()
        _, loss = self._infer(model, dataset, estimator, "val")
        loss.backward()
        arch_optim.step()

        # Phase 2: weight step on the training loss. zero_grad() also discards
        # the weight gradients left over from the validation backward above.
        model_optim.zero_grad()
        _, loss = self._infer(model, dataset, estimator, "train")
        loss.backward()
        if self.gradient_clip > 0:
            # Clip exactly the parameters this optimizer steps. model.parameters()
            # can also yield the registered alpha logits: a parent's parameters()
            # bypasses the choice modules' filtering overrides. Their stale
            # gradients from both backward passes would inflate the total norm
            # and over-shrink the weight gradients.
            weights = [p for group in model_optim.param_groups for p in group["params"]]
            nn.utils.clip_grad_norm_(weights, self.gradient_clip)
        model_optim.step()

    def _infer(self, model: BaseSpace, dataset, estimator: BaseEstimator, mask="train"):
        """Return ``(metric, loss)`` from ``estimator.infer`` on the given mask."""
        metric, loss = estimator.infer(model, dataset, mask=mask)
        return metric, loss

    @torch.no_grad()
    def export(self, nas_modules) -> dict:
        """Map each choice name in ``nas_modules`` to its module's exported choice.

        ``nas_modules`` is iterated as ``(name, module)`` pairs. When a name
        repeats, the first module seen for it wins.
        """
        result = dict()
        for name, module in nas_modules:
            if name not in result:
                result[name] = module.export()
        return result