# Source code for autogllight.nas.estimator.train_scratch

import torch.nn.functional as F
from ..space import BaseSpace
from .base import BaseEstimator
from autogllight.utils.evaluation import Acc
from autogllight.utils.backend.op import *


class TrainScratchEstimator(BaseEstimator):
    """
    Train scratch estimator.

    Train the model to get estimations: the model is handed to a
    user-supplied trainer and whatever the trainer reports is returned.

    Parameters
    ----------
    trainer : callable
        The trainer to train the model. It is invoked as
        ``trainer(model, dataset, mask, evaluation, *args, **kwargs)`` and
        its result is unpacked into ``(metrics, loss)``.
    evaluation : list of Evaluation, optional
        The evaluation metrics in module/train/evaluation. When omitted
        (or ``None``), a fresh ``[Acc()]`` list is created for the instance.
    """

    def __init__(self, trainer, evaluation=None):
        # Build the default per instance. A list literal in the signature is
        # evaluated once at definition time and would be shared (together
        # with its Acc object) by every estimator created without an
        # explicit ``evaluation``.
        if evaluation is None:
            evaluation = [Acc()]
        # The first argument to the base initializer is passed as None, as in
        # the original code; its meaning is defined by BaseEstimator.
        super().__init__(None, evaluation)
        self.trainer = trainer
        self.evaluation = evaluation

    def infer(self, model: BaseSpace, dataset, mask="train", *args, **kwargs):
        """
        Run the trainer on ``model`` and return its estimation.

        Parameters
        ----------
        model : BaseSpace
            The model to be estimated; forwarded to the trainer unchanged.
        dataset
            The dataset forwarded to the trainer unchanged.
        mask : str
            Forwarded to the trainer as its third positional argument.
            Defaults to ``"train"``.
        *args, **kwargs
            Extra arguments forwarded to the trainer after the evaluation
            list.

        Returns
        -------
        tuple
            ``(metrics, loss)`` exactly as produced by the trainer.
        """
        # Unpacking (rather than returning the raw result) keeps the original
        # contract that the trainer must yield exactly two values.
        metrics, loss = self.trainer(
            model, dataset, mask, self.evaluation, *args, **kwargs
        )
        return metrics, loss