# Source code for autogl.module.nas.algorithm.base

"""Base class for NAS algorithms."""
from ...model import BaseAutoModel
import torch
from abc import abstractmethod
from ....utils import get_device

class BaseNAS:
    """
    Base NAS algorithm class.

    Parameters
    ----------
    device: str or torch.device
        The device of the whole process. The value is resolved through
        ``get_device``; the default ``"auto"`` leaves the concrete choice
        to ``get_device``.
    """

    def __init__(self, device="auto") -> None:
        self.device = get_device(device)

    def to(self, device):
        """
        Change the device of the whole NAS search process.

        Parameters
        ----------
        device: str or torch.device
            The target device, resolved through ``get_device``.

        Returns
        -------
        BaseNAS
            ``self``, so calls can be chained (mirrors ``torch.nn.Module.to``).
        """
        self.device = get_device(device)
        # Returning self is backward-compatible: callers that ignored the
        # previous ``None`` result are unaffected, and chaining now works.
        return self

    # NOTE: BaseNAS does not use ``abc.ABCMeta``, so ``@abstractmethod`` is
    # not enforced when a subclass is instantiated; the NotImplementedError
    # raised below is what actually signals a missing override.
    @abstractmethod
    def search(self, space, dataset, estimator) -> BaseAutoModel:
        """
        The search process of NAS.

        Parameters
        ----------
        space :
            The search space. Constructed following nni.
        dataset : autogl.datasets
            Dataset to perform search on.
        estimator : autogl.module.nas.estimator.BaseEstimator
            The estimator to compute loss & metrics.

        Returns
        -------
        model: autogl.module.model.BaseAutoModel
            The searched model.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override ``search``.
        """
        raise NotImplementedError()