Source code for autogl.module.train.base

import numpy as np
import typing as _typing

import torch
import pickle

from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer

from ..model import (
    EncoderUniversalRegistry,
    DecoderUniversalRegistry,
    BaseEncoderMaintainer,
    BaseDecoderMaintainer,
    BaseAutoModel,
    ModelUniversalRegistry
)
from ..hpo import AutoModule
import logging
from .evaluation import Evaluation, get_feval, Acc
from ...utils import get_logger

LOGGER_ES = get_logger("early-stopping")

class _DummyModel(torch.nn.Module):
    def __init__(self, encoder: _typing.Union[BaseEncoderMaintainer, BaseAutoModel], decoder: _typing.Optional[BaseDecoderMaintainer]):
        super().__init__()
        if isinstance(encoder, BaseAutoModel):
            self.encoder = encoder.model
            self.decoder = None
        else:
            self.encoder = encoder.encoder
            self.decoder = None if decoder is None else decoder.decoder

    def __str__(self):
        return "DummyModel(encoder={}, decoder={})".format(self.encoder, self.decoder)

    def encode(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)
    
    def decode(self, *args, **kwargs):
        if self.decoder is None: return args[0]
        return self.decoder(*args, **kwargs)
    
    def forward(self, *args, **kwargs):
        res = self.encode(*args, **kwargs)
        return self.decode(res, *args, **kwargs)
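
# A minimal sketch of how the composed module behaves (illustrative only; ``enc`` and
# ``dec`` stand for already-initialized encoder / decoder maintainers):
#
#     model = _DummyModel(enc, dec)
#     out = model(data)           # equivalent to model.decode(model.encode(data), data)
#     emb = model.encode(data)    # encoder output only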

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(
        self,
        patience=7,
        verbose=False,
        delta=0,
        path="checkpoint.pt",
        trace_func=LOGGER_ES.info,
    ):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: LOGGER_ES.info
        """
        self.patience = 100 if patience is None else patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score <= self.best_score + self.delta:
            self.counter += 1
            if self.verbose is True:
                self.trace_func(
                    f"EarlyStopping counter: {self.counter} out of {self.patience}"
                )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decrease."""
        if self.verbose:
            self.trace_func(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ..."
            )
        self.best_param = pickle.dumps(model.state_dict())
        # torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

    def load_checkpoint(self, model):
        """Load models"""
        if hasattr(self, "best_param"):
            model.load_state_dict(pickle.loads(self.best_param))
        else:
            LOGGER_ES.warning("tried to load a checkpoint, but no checkpoint has been saved")
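
# A minimal usage sketch of the early-stopping helper inside a training loop
# (``model``, ``train_step`` and ``evaluate_loss`` are placeholders, not part of this module):
#
#     stopper = EarlyStopping(patience=10, verbose=True)
#     for epoch in range(max_epoch):
#         train_step(model)
#         val_loss = evaluate_loss(model)
#         stopper(val_loss, model)       # pickles the state_dict whenever the loss improves
#         if stopper.early_stop:
#             break
#     stopper.load_checkpoint(model)     # restore the best parameters seen so far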


class BaseTrainer(AutoModule):
    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, None],
        decoder: _typing.Union[BaseDecoderMaintainer, None],
        device: _typing.Union[torch.device, str],
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        """
        The basic trainer.
        Used to automatically train the problems, e.g., node classification, graph classification, etc.
        Parameters
        ----------
        model: `BaseModel` or `str`
            The (name of) model used to train and predict.
        init: `bool`
            If True(False), the model will (not) be initialized.
        """
        super().__init__(device)
        self.encoder = encoder
        self.decoder = None if isinstance(encoder, BaseAutoModel) else decoder
        self.feval = feval
        self.loss = loss

    def _compose_model(self):
        return _DummyModel(self.encoder, self.decoder).to(self.device)

    def _initialize(self):
        self.encoder.initialize()
        if self.decoder is not None:
            self.decoder.initialize(self.encoder)

    @property
    def feval(self) -> _typing.Sequence[_typing.Type[Evaluation]]:
        return self.__feval

    @feval.setter
    def feval(
        self,
        _feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ],
    ):
        self.__feval: _typing.Sequence[_typing.Type[Evaluation]] = get_feval(_feval)

    @property
    def model(self):
        # compatible with v0.2
        return self.encoder
    
    @model.setter
    def model(self, model):
        # compatible with v0.2
        self.encoder = model

    def to(self, device: _typing.Union[str, torch.device]):
        """
        Transfer the trainer to another device
        Parameters
        ----------
        device: `str` or `torch.device`
            The device this trainer will use
        """
        self.device = device
        if self.encoder is not None:
            self.encoder.to_device(self.device)
        if self.decoder is not None:
            self.decoder.to_device(self.device)

    def get_feval(
        self, return_major: bool = False
    ) -> _typing.Union[
        _typing.Type[Evaluation], _typing.Sequence[_typing.Type[Evaluation]]
    ]:
        """
        Parameters
        ----------
        return_major: ``bool``
            Whether to return the major ``feval``. Default ``False``.
        Returns
        -------
        ``evaluation`` or list of ``evaluation``:
            If ``return_major=True``, will return the major ``evaluation`` method.
            Otherwise, will return the ``evaluation`` element passed when constructing.
        """
        if return_major:
            if isinstance(self.feval, _typing.Sequence):
                return self.feval[0]
            else:
                return self.feval
        return self.feval

    @classmethod
    def save(cls, instance, path):
        with open(path, "wb") as output:
            pickle.dump(instance, output, pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path):
        with open(path, "rb") as inputs:
            instance = pickle.load(inputs)
            return instance
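
    # A minimal usage sketch of checkpointing a whole trainer via pickle (the path is
    # illustrative):
    #
    #     BaseTrainer.save(trainer, "/tmp/trainer.pkl")
    #     trainer = BaseTrainer.load("/tmp/trainer.pkl")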

    def duplicate_from_hyper_parameter(self, *args, **kwargs) -> "BaseTrainer":
        """Create a new trainer with the given hyper parameter."""
        raise NotImplementedError()

    def train(self, dataset, keep_valid_result):
        """
        Train on the given dataset.
        Parameters
        ----------
        dataset: The dataset used in training.
        keep_valid_result: `bool`
            If True(False), the validation result will (not) be saved after training.
        Returns
        -------
        """
        raise NotImplementedError()

    def predict(self, dataset, mask=None):
        """
        Predict on the given dataset.
        Parameters
        ----------
        dataset: The dataset used in predicting.
        mask: `train`, `val`, or `test`.
            The dataset mask.
        Returns
        -------
        prediction result
        """
        raise NotImplementedError()

    def predict_proba(self, dataset, mask=None, in_log_format=False):
        """
        Predict the probability on the given dataset.
        Parameters
        ----------
        dataset: The dataset used in predicting.
        mask: `train`, `val`, or `test`.
            The dataset mask.
        in_log_format: `bool`.
            If True(False), the probability will (not) be in log format.
        Returns
        -------
        The prediction result.
        """
        raise NotImplementedError()

    def get_valid_predict_proba(self):
        """Get the valid result (prediction probability)."""
        raise NotImplementedError()

    def get_valid_predict(self):
        """Get the valid result."""
        raise NotImplementedError()

    def get_valid_score(self, return_major=True):
        """Get the validation score."""
        raise NotImplementedError()

    def __repr__(self) -> str:
        raise NotImplementedError

    def evaluate(self, dataset, mask=None, feval=None):
        """
        Parameters
        ----------
        dataset: The dataset used in evaluation.
        mask: `train`, `val`, or `test`.
            The dataset mask.
        feval: The evaluation methods.
        Returns
        -------
        The evaluation result.
        """
        raise NotImplementedError

    def update_parameters(self, **kwargs):
        """
        Update parameters of this trainer
        """
        for k, v in kwargs.items():
            if k == "feval":
                self.feval = get_feval(v)
            elif k == "device":
                self.to(v)
            elif hasattr(self, k):
                setattr(self, k, v)
            else:
                raise KeyError("Cannot set parameter", k, "for trainer", self.__class__)
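
    # Example (illustrative values; ``feval`` names are resolved through ``get_feval``):
    #
    #     trainer.update_parameters(feval=["acc"], device="cpu")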

    def combined_hyper_parameter_space(self):
        return {
            "trainer": self.hyper_parameter_space,
            "encoder": self.encoder.hyper_parameter_space,
            "decoder": [] if self.decoder is None else self.decoder.hyper_parameter_space
        }
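
    # The returned mapping bundles the hyperparameter spaces of all components, e.g.
    # {"trainer": [...], "encoder": [...], "decoder": [...]}, presumably so that they
    # can be searched jointly by the HPO module.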


class _BaseClassificationTrainer(BaseTrainer):
    """ Base class of trainer for classification tasks """

    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None],
        decoder: _typing.Union[BaseDecoderMaintainer, str, None],
        num_features: int,
        num_classes: int,
        last_dim: _typing.Union[int, str] = "auto",
        device: _typing.Union[torch.device, str, None] = "auto",
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        self._encoder = None
        self._decoder = None
        self.num_features = num_features
        self.num_classes = num_classes
        self.last_dim: _typing.Union[int, str] = last_dim
        super(_BaseClassificationTrainer, self).__init__(
            encoder, decoder, device, feval, loss
        )
    
    @property
    def encoder(self):
        return self._encoder
    
    @encoder.setter
    def encoder(self, enc: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None]):
        if isinstance(enc, str):
            if enc in EncoderUniversalRegistry:
                self._encoder = EncoderUniversalRegistry.get_encoder(enc)(
                    self.num_features, final_dimension=self.last_dim, device=self.device, init=self.initialized
                )
            else:
                self._encoder = ModelUniversalRegistry.get_model(enc)(
                    self.num_features, final_dimension=self.last_dim, device=self.device
                )
                
        elif isinstance(enc, BaseEncoderMaintainer):
            self._encoder = enc
        elif isinstance(enc, BaseAutoModel):
            self._encoder = enc
            if self.decoder is not None:
                logging.warning("will disable decoder since a whole model is passed")
                self.decoder = None
        elif enc is None:
            self._encoder = None
        else:
            raise ValueError("Enc {} is not supported!".format(enc))
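        # re-run the property setters below so the newly assigned encoder/decoder
        # pick up the current dimensions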
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim

    @property
    def decoder(self):
        return self._decoder
    
    @decoder.setter
    def decoder(self, dec: _typing.Union[BaseDecoderMaintainer, str, None]):
        if isinstance(self.encoder, BaseAutoModel):
            logging.warning("Ignore passed dec since enc is a whole model")
            self._decoder = None
            return
        if isinstance(dec, str):
            self._decoder = DecoderUniversalRegistry.get_decoder(dec)(
                self.num_classes, input_dimension=self.last_dim, device=self.device, init=self.initialized
            )
        elif isinstance(dec, BaseDecoderMaintainer):
            self._decoder = dec
        elif dec is None:
            self._decoder = None
        else:
            raise ValueError("Dec {} is not supported!".format(dec))
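        # re-run the property setters below so the newly assigned decoder picks up
        # the current dimensions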
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim
    
    @property
    def num_classes(self):
        return self.__num_classes

    @num_classes.setter
    def num_classes(self, num_classes):
        self.__num_classes = num_classes
        if isinstance(self.encoder, BaseAutoModel):
            self.encoder.output_dimension = num_classes
        elif isinstance(self.decoder, BaseDecoderMaintainer):
            self.decoder.output_dimension = num_classes

    @property
    def last_dim(self):
        return self._last_dim
    
    @last_dim.setter
    def last_dim(self, dim):
        self._last_dim = dim
        if isinstance(self.encoder, AutoHomogeneousEncoderMaintainer):
            self.encoder.final_dimension = self._last_dim

    @property
    def num_features(self):
        return self._num_features
    
    @num_features.setter
    def num_features(self, num_features):
        self._num_features = num_features
        if self.encoder is not None:
            self.encoder.input_dimension = num_features
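
# A minimal sketch of how concrete subclasses typically exercise the string-based
# setters above (class and registry names in the call are illustrative; a real
# subclass must also implement train()/predict()):
#
#     trainer = SomeNodeClassificationTrainer(       # hypothetical concrete subclass
#         encoder="gcn", decoder="log_softmax",      # resolved through the registries
#         num_features=dataset.num_features, num_classes=dataset.num_classes,
#         device="cuda", feval=("acc",),
#     )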

class BaseNodeClassificationTrainer(_BaseClassificationTrainer):
    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None],
        decoder: _typing.Union[BaseDecoderMaintainer, str, None],
        num_features: int,
        num_classes: int,
        device: _typing.Union[torch.device, str, None] = None,
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        super(BaseNodeClassificationTrainer, self).__init__(
            encoder, decoder, num_features, num_classes, num_classes, device, feval, loss
        )

    # override num_classes property to support last_dim setting
    @property
    def num_classes(self):
        return self.__num_classes

    @num_classes.setter
    def num_classes(self, num_classes):
        self.__num_classes = num_classes
        if isinstance(self.encoder, BaseAutoModel):
            self.encoder.output_dimension = num_classes
        elif isinstance(self.decoder, BaseDecoderMaintainer):
            self.decoder.output_dimension = num_classes
        self.last_dim = num_classes


class BaseGraphClassificationTrainer(_BaseClassificationTrainer):
    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None] = None,
        decoder: _typing.Union[BaseDecoderMaintainer, str, None] = None,
        num_features: _typing.Optional[int] = None,
        num_classes: _typing.Optional[int] = None,
        num_graph_features: int = 0,
        last_dim: _typing.Union[int, str] = "auto",
        device: _typing.Union[torch.device, str, None] = None,
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        self._encoder = None
        self._decoder = None
        self.num_graph_features: int = num_graph_features
        super(BaseGraphClassificationTrainer, self).__init__(
            encoder, decoder, num_features, num_classes, last_dim, device, feval, loss
        )

    # override encoder and decoder to depend on graph level features
    @property
    def encoder(self):
        return self._encoder

    @encoder.setter
    def encoder(self, enc: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None]):
        if isinstance(enc, str):
            if enc in EncoderUniversalRegistry:
                self._encoder = EncoderUniversalRegistry.get_encoder(enc)(
                    self.num_features,
                    last_dim=self.last_dim,
                    num_graph_features=self.num_graph_features,
                    device=self.device,
                    init=self.initialized,
                )
            else:
                self._encoder = ModelUniversalRegistry.get_model(enc)(
                    self.num_features,
                    self.last_dim,
                    device=self.device,
                    num_graph_features=self.num_graph_features,
                )
        elif isinstance(enc, (BaseAutoModel, BaseEncoderMaintainer)):
            self._encoder = enc
            if isinstance(enc, BaseAutoModel) and self.decoder is not None:
                logging.warning("will disable decoder since a whole model is passed")
                self.decoder = None
        elif enc is None:
            self._encoder = None
        else:
            raise ValueError("Enc {} is not supported!".format(enc))
        # re-run the property setters so the new encoder picks up the current settings
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim
        self.num_graph_features = self.num_graph_features

    @property
    def decoder(self):
        if isinstance(self.encoder, BaseAutoModel):
            return None
        return self._decoder

    @decoder.setter
    def decoder(self, dec: _typing.Union[BaseDecoderMaintainer, str, None]):
        if isinstance(self.encoder, BaseAutoModel):
            logging.warning("Ignore passed dec since enc is a whole model")
            self._decoder = None
            return
        if isinstance(dec, str):
            self._decoder = DecoderUniversalRegistry.get_decoder(dec)(
                self.num_classes,
                input_dim=self.last_dim,
                num_graph_features=self.num_graph_features,
                device=self.device,
                init=self.initialized,
            )
        elif isinstance(dec, BaseDecoderMaintainer) or dec is None:
            self._decoder = dec
        else:
            raise ValueError("Invalid decoder setting")
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim

    # override num_classes property to support last_dim setting
    @property
    def num_classes(self):
        return self.__num_classes

    @num_classes.setter
    def num_classes(self, num_classes):
        self.__num_classes = num_classes
        if isinstance(self.encoder, BaseAutoModel):
            self.encoder.output_dimension = num_classes
        elif isinstance(self.decoder, BaseDecoderMaintainer):
            self.decoder.output_dimension = num_classes

    @property
    def num_graph_features(self):
        return self._num_graph_features

    @num_graph_features.setter
    def num_graph_features(self, num_graph_features):
        self._num_graph_features = num_graph_features
        if self.encoder is not None:
            self.encoder.num_graph_features = self._num_graph_features
        if self.decoder is not None:
            self.decoder.num_graph_features = self._num_graph_features

# TODO: according to discussion, link prediction may not belong to classification tasks
class BaseLinkPredictionTrainer(_BaseClassificationTrainer):
    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None] = None,
        decoder: _typing.Union[BaseDecoderMaintainer, str, None] = None,
        num_features: _typing.Optional[int] = None,
        last_dim: _typing.Union[int, str] = "auto",
        device: _typing.Union[torch.device, str, None] = None,
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        super(BaseLinkPredictionTrainer, self).__init__(
            encoder, decoder, num_features, 2, last_dim, device, feval, loss
        )

    # override decoder since no num_classes is needed
    @property
    def decoder(self):
        if isinstance(self.encoder, BaseAutoModel):
            return None
        return self._decoder

    @decoder.setter
    def decoder(self, dec: _typing.Union[BaseDecoderMaintainer, str, None]):
        if isinstance(self.encoder, BaseAutoModel):
            logging.warning("Ignore passed dec since enc is a whole model")
            self._decoder = None
            return
        if isinstance(dec, str):
            self._decoder = DecoderUniversalRegistry.get_decoder(dec)(
                input_dim=self.last_dim, device=self.device, init=self.initialized
            )
        elif isinstance(dec, BaseDecoderMaintainer):
            self._decoder = dec
        elif dec is None:
            self._decoder = None
        else:
            raise ValueError("Invalid decoder setting")
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim

# ============== Het =================
class BaseNodeClassificationHetTrainer(BaseNodeClassificationTrainer):
    """ Base class of trainer for heterogeneous node classification tasks """

    def __init__(
        self,
        model: _typing.Union[BaseAutoModel, str],
        dataset: None,
        num_features: int,
        num_classes: int,
        device: _typing.Union[torch.device, str, None] = "auto",
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        self._dataset = dataset
        super(BaseNodeClassificationHetTrainer, self).__init__(
            model, None, num_features, num_classes, device, feval, loss
        )
        self.from_dataset(dataset)

    def from_dataset(self, dataset):
        self._dataset = dataset
        if self.encoder is not None:
            self.encoder.from_dataset(self._dataset)

    @property
    def encoder(self):
        return self._encoder

    @encoder.setter
    def encoder(self, enc: _typing.Union[BaseAutoModel, str, None]):
        if isinstance(enc, str):
            self._encoder = ModelUniversalRegistry.get_model(enc)(
                self.num_features, self.num_classes, device=self.device
            )
        elif isinstance(enc, BaseAutoModel):
            self._encoder = enc
            if self.decoder is not None:
                logging.warning("will disable decoder since a whole model is passed")
                self.decoder = None
        elif enc is None:
            self._encoder = None
        else:
            raise ValueError("Enc {} is not supported!".format(enc))
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        if self._dataset is not None:
            self.from_dataset(self._dataset)