# Code in this file is reproduced from https://github.com/microsoft/nni with some changes.
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import register_nas_algo
from .base import BaseNAS
from ..space import BaseSpace
from ..utils import (
AverageMeterGroup,
replace_layer_choice,
replace_input_choice,
get_module_order,
sort_replaced_module,
PathSamplingInputChoice,
PathSamplingLayerChoice,
)
from nni.nas.pytorch.fixed import apply_fixed_architecture
from tqdm import tqdm
from datetime import datetime
import numpy as np
from ....utils import get_logger
LOGGER = get_logger("RL_NAS")
def _get_mask(sampled, total):
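    """Return a boolean mask of length ``total`` that is True at the sampled
    index (or indices, when ``sampled`` is a list).

    For example, ``_get_mask(1, 3)`` gives ``[False, True, False]`` and
    ``_get_mask([0, 2], 3)`` gives ``[True, False, True]``.
    """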
multihot = [
i == sampled or (isinstance(sampled, list) and i in sampled)
for i in range(total)
]
return torch.tensor(multihot, dtype=torch.bool) # pylint: disable=not-callable
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList(
[nn.LSTMCell(size, size, bias=bias) for _ in range(self.lstm_num_layers)]
)
def forward(self, inputs, hidden):
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# current implementation only supports batch size equals 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_h, next_c
class ReinforceField:
"""
    A field with ``name`` and ``total`` choices. ``choose_one`` is true if exactly one choice is meant to be
    selected; otherwise, any number of choices can be chosen.
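
    Examples
    --------
    A minimal sketch: a field describing a layer choice with three candidate
    operations, of which exactly one is picked (the field name is hypothetical)::

        field = ReinforceField("op_0", total=3, choose_one=True)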
"""
def __init__(self, name, total, choose_one):
self.name = name
self.total = total
self.choose_one = choose_one
def __repr__(self):
return f"ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})"
class ReinforceController(nn.Module):
"""
A controller that mutates the graph with RL.
Parameters
----------
fields : list of ReinforceField
List of fields to choose.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
skip_target : float
        Target probability that a skip connection will appear.
temperature : float
Temperature constant that divides the logits.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
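
    Examples
    --------
    A minimal sketch of how the controller is driven; the field names are
    hypothetical::

        fields = [
            ReinforceField("op_0", total=3, choose_one=True),
            ReinforceField("op_1", total=3, choose_one=True),
        ]
        controller = ReinforceController(fields)
        sample = controller.resample()  # e.g. {"op_0": 2, "op_1": 0}
        # ``controller.sample_log_prob`` and ``controller.sample_entropy`` are
        # accumulated during sampling and consumed by the REINFORCE update.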
"""
def __init__(
self,
fields,
lstm_size=64,
lstm_num_layers=1,
tanh_constant=1.5,
skip_target=0.4,
temperature=None,
entropy_reduction="sum",
):
super(ReinforceController, self).__init__()
self.fields = fields
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.skip_target = skip_target
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(
torch.tensor(
[1.0 - self.skip_target, self.skip_target]
), # pylint: disable=not-callable
requires_grad=False,
)
assert entropy_reduction in [
"sum",
"mean",
], "Entropy reduction must be one of sum and mean."
self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.soft = nn.ModuleDict(
{
field.name: nn.Linear(self.lstm_size, field.total, bias=False)
for field in fields
}
)
self.embedding = nn.ModuleDict(
{field.name: nn.Embedding(field.total, self.lstm_size) for field in fields}
)
def resample(self):
self._initialize()
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field)
return result
def _initialize(self):
self._inputs = self.g_emb.data
self._c = [
torch.zeros(
(1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device,
)
for _ in range(self.lstm_num_layers)
]
self._h = [
torch.zeros(
(1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device,
)
for _ in range(self.lstm_num_layers)
]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _sample_single(self, field):
self._lstm_next_step()
logit = self.soft[field.name](self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if field.choose_one:
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, sampled)
self._inputs = self.embedding[field.name](sampled)
else:
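            # Multi-choice branch: each of the ``total`` candidates is treated as an
            # independent binary keep/drop decision. The single logit per candidate is
            # expanded to two-class logits [-logit, logit], a decision is sampled per
            # candidate, and a KL penalty toward ``skip_targets`` regularizes how many
            # connections are kept.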
logit = logit.view(-1, 1)
logit = torch.cat(
[-logit, logit], 1
) # pylint: disable=invalid-unary-operand-type
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, sampled)
sampled = sampled.nonzero().view(-1)
if sampled.sum().item():
self._inputs = (
torch.sum(self.embedding[field.name](sampled.view(-1)), 0)
/ (1.0 + torch.sum(sampled))
).unsqueeze(0)
else:
self._inputs = torch.zeros(
1, self.lstm_size, device=self.embedding[field.name].weight.device
)
sampled = sampled.detach().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (
log_prob * torch.exp(-log_prob)
).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
if len(sampled) == 1:
sampled = sampled[0]
return sampled
@register_nas_algo("rl")
class RL(BaseNAS):
"""
    RL-based NAS with an ENAS-style LSTM controller, as in GraphNAS.
Parameters
----------
num_epochs : int
Number of epochs planned for training.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
        Number of steps between logging messages.
grad_clip : float
Gradient clipping. Set to 0 to disable. Default: 5.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_lr : float
Learning rate for RL controller.
ctrl_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`ReinforceController`.
n_warmup : int
Number of epochs for training super network.
model_lr : float
Learning rate for super network.
model_wd : float
Weight decay for super network.
    disable_progress : bool
        Whether to disable the progress bar.
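
    Examples
    --------
    A minimal usage sketch; ``space``, ``dataset`` and ``estimator`` are assumed to
    be a concrete :class:`BaseSpace`, a dataset, and an estimator provided by the
    surrounding NAS framework::

        algo = RL(num_epochs=5, ctrl_steps_aggregate=20)
        best_arch = algo.search(space, dataset, estimator)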
"""
def __init__(
self,
num_epochs=5,
device="auto",
log_frequency=None,
grad_clip=5.0,
entropy_weight=0.0001,
skip_weight=0.8,
baseline_decay=0.999,
ctrl_lr=0.00035,
ctrl_steps_aggregate=20,
ctrl_kwargs=None,
n_warmup=100,
model_lr=5e-3,
model_wd=5e-4,
disable_progress=False,
):
super().__init__(device)
self.num_epochs = num_epochs
self.log_frequency = log_frequency
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.0
self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.grad_clip = grad_clip
self.ctrl_kwargs = ctrl_kwargs
self.ctrl_lr = ctrl_lr
self.n_warmup = n_warmup
self.model_lr = model_lr
self.model_wd = model_wd
self.disable_progress = disable_progress
    def search(self, space: BaseSpace, dset, estimator):
self.model = space
self.dataset = dset # .to(self.device)
self.estimator = estimator
# replace choice
self.nas_modules = []
k2o = get_module_order(self.model)
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
# to device
self.model = self.model.to(self.device)
# fields
self.nas_fields = [
ReinforceField(
name,
len(module),
isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
)
for name, module in self.nas_modules
]
self.controller = ReinforceController(
self.nas_fields, **(self.ctrl_kwargs or {})
)
self.ctrl_optim = torch.optim.Adam(
self.controller.parameters(), lr=self.ctrl_lr
)
# train
with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
for i in bar:
l2 = self._train_controller(i)
bar.set_postfix(reward_controller=l2)
selection = self.export()
arch = space.parse_model(selection, self.device)
# print(selection,arch)
return arch
def _train_controller(self, epoch):
self.model.eval()
self.controller.train()
self.ctrl_optim.zero_grad()
rewards = []
with tqdm(
range(self.ctrl_steps_aggregate), disable=self.disable_progress
) as bar:
for ctrl_step in bar:
self._resample()
metric, loss = self._infer(mask="val")
reward = metric
bar.set_postfix(acc=metric, loss=loss.item())
LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}")
rewards.append(reward)
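                    # REINFORCE update: add an entropy bonus to the reward, move the
                    # moving-average baseline toward it, and use
                    # log_prob * (reward - baseline) (plus an optional skip penalty),
                    # averaged over ``ctrl_steps_aggregate`` samples, as the surrogate loss.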
if self.entropy_weight:
reward += (
self.entropy_weight * self.controller.sample_entropy.item()
)
self.baseline = self.baseline * self.baseline_decay + reward * (
1 - self.baseline_decay
)
loss = self.controller.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.controller.sample_skip_penalty
loss /= self.ctrl_steps_aggregate
loss.backward()
if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(
self.controller.parameters(), self.grad_clip
)
self.ctrl_optim.step()
self.ctrl_optim.zero_grad()
if (
self.log_frequency is not None
and ctrl_step % self.log_frequency == 0
):
LOGGER.debug(
"RL Epoch [%d/%d] Step [%d/%d] %s",
epoch + 1,
self.num_epochs,
ctrl_step + 1,
self.ctrl_steps_aggregate,
)
return sum(rewards) / len(rewards)
def _resample(self):
result = self.controller.resample()
self.arch = self.model.parse_model(result, device=self.device)
self.selection = result
def export(self):
self.controller.eval()
with torch.no_grad():
return self.controller.resample()
def _infer(self, mask="train"):
metric, loss = self.estimator.infer(self.arch._model, self.dataset, mask=mask)
return metric[0], loss
@register_nas_algo("graphnas")
class GraphNasRL(BaseNAS):
"""
    RL in GraphNAS.
Parameters
----------
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
num_epochs : int
Number of epochs planned for training.
log_frequency : int
        Number of steps between logging messages.
grad_clip : float
Gradient clipping. Set to 0 to disable. Default: 5.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_lr : float
Learning rate for RL controller.
ctrl_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`ReinforceController`.
n_warmup : int
Number of epochs for training super network.
model_lr : float
Learning rate for super network.
model_wd : float
Weight decay for super network.
    topk : int
        Number of top architectures kept during the training process.
    disable_progress : bool
        Whether to disable the progress bar.
    hardware_metric_limit : float
        If not ``None``, architectures whose first hardware metric reaches or exceeds this limit are excluded from the top-k history.
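
    Examples
    --------
    A minimal usage sketch; ``space``, ``dataset`` and ``estimator`` are assumed to
    be a concrete :class:`BaseSpace`, a dataset, and an estimator provided by the
    surrounding NAS framework::

        algo = GraphNasRL(num_epochs=10, topk=5)
        best_arch = algo.search(space, dataset, estimator)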
"""
def __init__(
self,
device="auto",
num_epochs=10,
log_frequency=None,
grad_clip=5.0,
entropy_weight=0.0001,
skip_weight=0,
baseline_decay=0.95,
ctrl_lr=0.00035,
ctrl_steps_aggregate=100,
ctrl_kwargs=None,
n_warmup=100,
model_lr=5e-3,
model_wd=5e-4,
topk=5,
disable_progress=False,
hardware_metric_limit=None,
):
super().__init__(device)
self.num_epochs = num_epochs
self.log_frequency = log_frequency
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.grad_clip = grad_clip
self.ctrl_kwargs = ctrl_kwargs
self.ctrl_lr = ctrl_lr
self.n_warmup = n_warmup
self.model_lr = model_lr
self.model_wd = model_wd
self.hist = []
self.topk = topk
self.disable_progress = disable_progress
self.hardware_metric_limit = hardware_metric_limit
    def search(self, space: BaseSpace, dset, estimator):
self.model = space
self.dataset = dset # .to(self.device)
self.estimator = estimator
# replace choice
self.nas_modules = []
k2o = get_module_order(self.model)
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
# to device
self.model = self.model.to(self.device)
# fields
self.nas_fields = [
ReinforceField(
name,
len(module),
isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
)
for name, module in self.nas_modules
]
self.controller = ReinforceController(
self.nas_fields,
lstm_size=100,
temperature=5.0,
tanh_constant=2.5,
**(self.ctrl_kwargs or {}),
)
self.ctrl_optim = torch.optim.Adam(
self.controller.parameters(), lr=self.ctrl_lr
)
# train
with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
for i in bar:
l2 = self._train_controller(i)
bar.set_postfix(reward_controller=l2)
# selection=self.export()
        selections = [x[1] for x in self.hist]
        candidate_accs = [-x[0] for x in self.hist]
        # print('candidate accuracies', candidate_accs)
selection = self._choose_best(selections)
arch = space.parse_model(selection, self.device)
# print(selection,arch)
return arch
def _choose_best(self, selections):
        # GraphNAS keeps the top 5 models, evaluates each one 20 times, and chooses the best.
results = []
for selection in selections:
accs = []
for i in tqdm(range(20), disable=self.disable_progress):
self.arch = self.model.parse_model(selection, device=self.device)
metric, loss, _ = self._infer(mask="val")
accs.append(metric)
result = np.mean(accs)
LOGGER.info(
"selection {} \n acc {:.4f} +- {:.4f}".format(
selection, np.mean(accs), np.std(accs) / np.sqrt(20)
)
)
results.append(result)
best_selection = selections[np.argmax(results)]
return best_selection
def _train_controller(self, epoch):
self.model.eval()
self.controller.train()
self.ctrl_optim.zero_grad()
rewards = []
baseline = None
        # diff: GraphNAS trains 100 and derives 100 architectures per epoch (10 epochs); here we just train 100 per epoch (20 epochs). The total number of sampled architectures is the same (2000).
with tqdm(
range(self.ctrl_steps_aggregate), disable=self.disable_progress
) as bar:
for ctrl_step in bar:
self._resample()
metric, loss, hardware_metric = self._infer(mask="val")
reward = metric
# bar.set_postfix(acc=metric,loss=loss.item())
LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}")
                # diff: we do not do reward shaping as in the GraphNAS code
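                # Keep a top-k history of sampled architectures: entries are stored
                # as [-metric, selection] so that an ascending sort keeps the best,
                # and architectures violating ``hardware_metric_limit`` (if set) are
                # skipped.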
if (
self.hardware_metric_limit is None
or hardware_metric[0] < self.hardware_metric_limit
):
self.hist.append([-metric, self.selection])
if len(self.hist) > self.topk:
self.hist.sort(key=lambda x: x[0])
self.hist.pop()
rewards.append(reward)
if self.entropy_weight:
reward += (
self.entropy_weight * self.controller.sample_entropy.item()
)
if not baseline:
baseline = reward
else:
baseline = baseline * self.baseline_decay + reward * (
1 - self.baseline_decay
)
loss = self.controller.sample_log_prob * (reward - baseline)
self.ctrl_optim.zero_grad()
loss.backward()
self.ctrl_optim.step()
bar.set_postfix(acc=metric, max_acc=max(rewards))
LOGGER.info("epoch:{}, mean rewards:{}".format(epoch, sum(rewards) / len(rewards)))
return sum(rewards) / len(rewards)
def _resample(self):
result = self.controller.resample()
self.arch = self.model.parse_model(result, device=self.device)
self.selection = result
def export(self):
self.controller.eval()
with torch.no_grad():
return self.controller.resample()
def _infer(self, mask="train"):
metric, loss = self.estimator.infer(self.arch._model, self.dataset, mask=mask)
return metric[0], loss, metric[1:]