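# Code in this file is reproduced, with some changes, from an upstream project.
# NOTE(review): the original source link is missing from this comment - restore it.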
import copy
from logging import Logger
from numpy.core.fromnumeric import sort

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import register_nas_algo
from .base import BaseNAS
from import BaseSpace
from ..utils import (
from tqdm import tqdm, trange
from ....utils import get_logger

import numpy as np
LOGGER = get_logger("SPOS")

import collections
import dataclasses
import random

@dataclasses.dataclass
class Individual:
    """
    A class that represents an individual.

    Holds two attributes, where ``x`` is the model and ``y`` is the metric (e.g., accuracy).
    Instances are built positionally as ``Individual(config, metric)``.
    """

    x: dict
    y: float

class Evolution:
    """
    Algorithm for regularized evolution (i.e. aging evolution).
    Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Classifier Architecture Search".

    Parameters
    ----------
    optimize_mode : str
        Can be one of "maximize" and "minimize". Default: maximize.
    population_size : int
        The number of individuals to keep in the population. Default: 100.
    cycles : int
        The number of cycles (trials) the algorithm should run for. Default: 20000.
    sample_size : int
        The number of individuals that should participate in each tournament. Default: 25.
    mutation_prob : float
        Probability that mutation happens in each dim. Default: 0.05
    disable_progress : bool
        Passed as ``disable`` to the tqdm progress bars. Default: False.
    """

    def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000,
                 mutation_prob=0.05, disable_progress=False):
        assert optimize_mode in ['maximize', 'minimize']
        assert sample_size < population_size
        self.optimize_mode = optimize_mode
        self.population_size = population_size
        self.sample_size = sample_size
        self.cycles = cycles
        self.mutation_prob = mutation_prob
        self.disable_progress = disable_progress
        self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf')

        self._success_count = 0
        # A deque so that the oldest individual can be dropped from the left in O(1).
        self._population = collections.deque()
        self._running_models = []
        self._polling_interval = 2.
        # Every individual ever evaluated; the final answer is picked from here.
        self._history = []

    def best_parent(self, sample_size=None):
        """Get the config of the best parent.

        Parameters
        ----------
        sample_size : int or None
            Number of individuals taking part in the tournament. They are drawn
            at random from the current population. ``None`` means the whole
            population competes.

        Returns
        -------
        dict
            The ``x`` (config) of the individual with the best ``y`` among the
            sampled ones, according to ``optimize_mode``.

        Raises
        ------
        ValueError
            If the population is empty (raised by ``max`` / ``min``).
        """
        samples = list(self._population)  # copy population
        if sample_size is not None:
            # Shuffle before truncating so the tournament is a random subset
            # rather than always the oldest `sample_size` individuals.
            random.shuffle(samples)
            samples = samples[:sample_size]
        if self.optimize_mode == 'maximize':
            parent = max(samples, key=lambda sample: sample.y)
        else:
            parent = min(samples, key=lambda sample: sample.y)
        return parent.x

    def _prepare(self):
        """Create the samplers used for the initial population and for mutation."""
        self.uniform = UniformSampler(self.nas_modules)
        self.mutation = MutationSampler(self.nas_modules, self.mutation_prob)

    def _get_metric(self, config):
        """Apply ``config`` to the NAS modules and evaluate the model on the ``val`` mask.

        Returns the first entry of the metric returned by ``estimator.infer``.
        """
        for name, module in self.nas_modules:
            module.sampled = config[name]
        # todo: this may be computational expensive
        # model=self.model.parse_model(config,self.device)
        with torch.no_grad():
            metric, loss = self.estimator.infer(self.model, self.dataset, mask='val')
        return metric[0]

    def _record(self, individual):
        """Add an evaluated individual to both the live population and the history."""
        self._population.append(individual)
        self._history.append(individual)

    def search(self, space: "BaseSpace", nas_modules, dset, estimator, device):
        """Run the evolutionary search and return the best config found.

        Parameters
        ----------
        space : BaseSpace
            The model that is evaluated; its choices are set through ``nas_modules``.
        nas_modules : iterable of (name, module)
            Pairs whose ``module.sampled`` attribute is assigned from a config,
            and whose ``len(module)`` gives the number of choices.
        dset
            Dataset forwarded unchanged to ``estimator.infer``.
        estimator
            Object exposing ``infer(model, dataset, mask=...)`` returning
            ``(metric, loss)``.
        device
            Stored on the instance; not otherwise used in this method.

        Returns
        -------
        dict
            The config with the best metric over the whole history.
        """
        self.model = space
        self.dataset = dset
        self.estimator = estimator
        self.nas_modules = nas_modules
        self.device = device

        self._prepare()

        LOGGER.info('Initializing the first population.')
        with tqdm(range(self.population_size), disable=self.disable_progress) as bar:
            for i in bar:
                config = self.uniform.resample()
                metric = self._get_metric(config)
                individual = Individual(config, metric)
                # LOGGER.debug('Individual created: %s', str(individual))
                self._record(individual)
                bar.set_postfix(
                    metric=metric,
                    max=max(x.y for x in self._population),
                    min=min(x.y for x in self._population),
                )

        LOGGER.info('Running mutations.')
        with tqdm(range(self.cycles), disable=self.disable_progress) as bar:
            for i in bar:
                parent = self.best_parent(self.sample_size)
                config = self.mutation.resample(parent)
                metric = self._get_metric(config)  # todo : add aging factor
                individual = Individual(config, metric)
                LOGGER.debug('Individual created: %s', str(individual))
                self._record(individual)
                if len(self._population) > self.population_size:
                    # Aging: discard the oldest individual, not the worst one.
                    self._population.popleft()
                bar.set_postfix(
                    metric=metric,
                    max_h=max(x.y for x in self._history),
                    max=max(x.y for x in self._population),
                    min=min(x.y for x in self._population),
                )

        # todo: origin is best in history | or the population may need to be retrained
        self._history.sort(key=lambda x: x.y)
        # best=self.best_parent()
        if self.optimize_mode == 'maximize':
            best = self._history[-1].x
        else:
            best = self._history[0].x
        return best

class MutationSampler:
    """uniform mutator

    Parameters
    ----------
    nas_modules : iterable of (name, module)
        nas_modules in NAS algorithms , including choices of modules.
        ``len(module)`` is taken as the number of choices for ``name``.
    mutation_prob : float
        probability of doing mutation in each choice.
    """

    def __init__(self, nas_modules, mutation_prob):
        # Map each choice name to the number of options it offers.
        self.selection_dict = {k: len(v) for k, v in nas_modules}
        self.mutation_prob = mutation_prob

    def resample(self, parent):
        """Return a mutated copy of ``parent``.

        Parameters
        ----------
        parent : dict
            parent individual's choices, keyed by choice name.

        Returns
        -------
        dict
            A new dict with the same keys; each value is independently
            re-drawn uniformly from its range with probability
            ``mutation_prob`` and copied from the parent otherwise.
        """
        search_space = self.selection_dict
        child = {}
        for k, v in parent.items():
            if random.uniform(0, 1) < self.mutation_prob:
                child[k] = np.random.choice(range(search_space[k]))  # do not exclude the original operator
            else:
                child[k] = v
        return child

class UniformSampler:
    """Uniform Sampler

    Parameters
    ----------
    nas_modules : iterable of (name, module)
        nas_modules in NAS algorithms , including choices of modules.
        ``len(module)`` is taken as the number of choices for ``name``.
    """

    def __init__(self, nas_modules):
        # Map each choice name to the number of options it offers.
        self.selection_dict = {k: len(v) for k, v in nas_modules}

    def resample(self):
        """Draw one option index uniformly at random for every choice.

        Returns
        -------
        dict
            Maps each choice name to an index in ``range(number of options)``.
        """
        selection = {}
        for k, v in self.selection_dict.items():
            selection[k] = np.random.choice(range(v))
        return selection

@register_nas_algo("spos")
class Spos(BaseNAS):
    """
    SPOS trainer.

    The super network is first trained with uniformly sampled paths, then an
    evolutionary search (see ``Evolution``) selects the final architecture.

    Parameters
    ----------
    n_warmup : int
        Number of epochs for training super network.
    grad_clip : float
        Max gradient norm used for clipping; clipping is skipped when it is
        not positive.
    disable_progress : bool
        Passed as ``disable`` to the tqdm progress bars.
    model_lr : float
        Learning rate for super network.
    model_wd : float
        Weight decay for super network.
    Other parameters see Evolution
    """

    def __init__(
        self,
        n_warmup=1000,
        grad_clip=5.0,
        disable_progress=False,
        optimize_mode='maximize',
        population_size=100,
        sample_size=25,
        cycles=20000,
        mutation_prob=0.05,
        device="cuda",
        model_lr=5e-3,
        model_wd=5e-4,
    ):
        super().__init__(device)
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.n_warmup = n_warmup
        self.disable_progress = disable_progress
        self.grad_clip = grad_clip
        self.optimize_mode = optimize_mode
        self.population_size = population_size
        self.sample_size = sample_size
        self.cycles = cycles
        self.mutation_prob = mutation_prob

    def _prepare(self):
        """Replace the choice modules, then build the optimizer, sampler and evolution."""
        # replace choice
        # NOTE(review): presumably the module order is recorded before the
        # replacement so the replaced modules can be re-sorted - confirm
        # against the helpers in ..utils.
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

        # to device
        self.model = self.model.to(self.device)

        self.model_optim = torch.optim.Adam(
            self.model.parameters(), lr=self.model_lr, weight_decay=self.model_wd
        )

        # controller
        self.controller = UniformSampler(self.nas_modules)

        # Evolution
        # Forward the configured optimize_mode instead of a hard-coded
        # 'maximize', so that the constructor argument actually takes effect.
        self.evolve = Evolution(
            optimize_mode=self.optimize_mode,
            population_size=self.population_size,
            sample_size=self.sample_size,
            cycles=self.cycles,
            mutation_prob=self.mutation_prob,
            disable_progress=self.disable_progress,
        )

    def search(self, space: BaseSpace, dset, estimator):
        """Train the super network, run the evolutionary search and return the parsed model.

        Parameters
        ----------
        space : BaseSpace
            The search space / super network; also used to parse the final model.
        dset
            Dataset forwarded unchanged to ``estimator.infer``.
        estimator
            Object exposing ``infer(model, dataset, mask=...)`` returning
            ``(metric, loss)``.

        Returns
        -------
        The result of ``space.parse_model(selection, self.device)``.
        """
        self.model = space
        self.dataset = dset
        self.estimator = estimator

        self._prepare()
        self._train()  # train using uniform sampling
        self._search()  # search using evolutionary algorithm

        selection = self.export()  # here may sample N , retrain N ,and get best
        print(selection)
        return space.parse_model(selection, self.device)

    def _search(self):
        """Run the evolution and remember the best config for ``export``."""
        self.best_config = self.evolve.search(
            self.model,
            self.nas_modules,
            self.dataset,
            self.estimator,
            self.device,
        )

    def _train(self):
        """Train the super network for ``n_warmup`` epochs, reporting validation results."""
        with tqdm(range(self.n_warmup), disable=self.disable_progress) as bar:
            for i in bar:
                acc, l1 = self._train_one_epoch(i)
                with torch.no_grad():
                    val_acc, val_loss = self._infer("val")
                bar.set_postfix(loss=l1, acc=acc, val_acc=val_acc, val_loss=val_loss.item())

    def _train_one_epoch(self, epoch):
        """Do one optimizer step on a freshly sampled path.

        Returns the training metric and the loss as a Python number.
        """
        self.model.train()
        self.model_optim.zero_grad()
        self._resample()  # uniform sampling
        metric, loss = self._infer(mask="train")
        loss.backward()
        if self.grad_clip > 0:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.model_optim.step()
        return metric, loss.item()

    def _resample(self):
        """Sample a config uniformly and apply it to every NAS module."""
        result = self.controller.resample()
        for name, module in self.nas_modules:
            module.sampled = result[name]

    def export(self):
        """Return the best config found by ``_search``."""
        return self.best_config

    def _infer(self, mask="train"):
        """Evaluate the model on ``mask``; return the first metric and the loss."""
        metric, loss = self.estimator.infer(self.model, self.dataset, mask=mask)
        return metric[0], loss