Source code for autogl.module.nas.estimator.base

"""
Base estimator of NAS
"""

from abc import abstractmethod
from ..space import BaseSpace
from typing import Tuple
from ...train.evaluation import Evaluation, Acc
import torch.nn.functional as F
import torch


[docs]class BaseEstimator: """ The estimator of NAS model. Parameters ---------- loss_f: callable Default loss function for evaluation evaluation: list of autogl.module.train.evaluation.Evaluation Default evaluation metric """ def __init__(self, loss_f: str = "nll_loss", evaluation=[Acc()]): self.loss_f = loss_f self.evaluation = evaluation def setLossFunction(self, loss_f: str): self.loss_f = loss_f def setEvaluation(self, evaluation): self.evaluation = evaluation
[docs] @abstractmethod def infer( self, model: BaseSpace, dataset, mask="train" ) -> Tuple[torch.Tensor, torch.Tensor]: """ Calculate the loss and metrics of given model on given dataset using specified masks. Parameters ---------- model: autogl.module.nas.space.BaseSpace The model in space. dataset: autogl.dataset The dataset to perform infer mask: str The mask to evalute on dataset Return ------ metrics: list of float the metrics on given datasets. loss: torch.Tensor the loss on given datasets. Note that loss should be differentiable. """ raise NotImplementedError()