AutoGL SSL Trainer

The AutoGL project uses trainers to implement graph self-supervised learning methods. Currently, we only support GraphCL with semi-supervised downstream tasks:

  • GraphCLSemisupervisedTrainer: uses the GraphCL algorithm for semi-supervised downstream tasks. The main interfaces are listed below, followed by a short usage sketch.
    • train(self, dataset, keep_valid_result=True): Train on the given dataset and, optionally, keep the validation results.
      • dataset: the graph dataset to train on
      • keep_valid_result: if True, save the validation results after training. Only when keep_valid_result is True and training has finished do get_valid_score, get_valid_predict_proba and get_valid_predict return meaningful results.
    • predict(self, dataset, mask="test"): Predict on the given dataset.
      • dataset: the graph dataset to predict on
      • mask: "train", "val" or "test", the dataset split to use
    • predict_proba(self, dataset, mask="test", in_log_format=False): Predict class probabilities on the given dataset.
      • dataset: the graph dataset to predict on
      • mask: "train", "val" or "test", the dataset split to use
      • in_log_format: if True, the probabilities are returned in log format
    • evaluate(self, dataset, mask="val", feval=None): Evaluate the model on the given dataset.
      • dataset: the graph dataset to evaluate on
      • mask: "train", "val" or "test", the dataset split to use
      • feval: the evaluation method(s) to use. If None, the feval given at initialization is used.
    • get_valid_score(self, return_major=True): Get the validation scores after training.
      • return_major: if True, only the major result is returned.
    • get_valid_predict_proba(self): Get the prediction probabilities on the validation set after training.
    • get_valid_predict(self): Get the predictions on the validation set after training.
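
For example, after training with keep_valid_result=True you can read the validation results back from the trainer. Below is a minimal sketch, assuming an already-initialized trainer and a dataset with train/val/test splits (see the sections below):

  trainer.train(dataset, keep_valid_result=True)
  print(trainer.get_valid_score())                         # validation score(s) under the chosen feval
  val_proba = trainer.get_valid_predict_proba()            # prediction probabilities on the validation set
  test_proba = trainer.predict_proba(dataset, mask="test") # probabilities on the test split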

Lazy Initialization

For the same reason as :ref:model, we also use lazy initialization for all trainers. Only (part of) the hyper-parameters are set when __init__() is called. The trainer gets its core model only after initialize() is explicitly called, which is done automatically in the solver and in duplicate_from_hyper_parameter(), after all the hyper-parameters have been set properly.

For example, if you want to use gcn as the encoder, a simple mlp as the decoder, and an mlp classifier to solve a graph classification problem, there are three steps to follow.

  • First, import everything you need

    from autogl.module.train.ssl import GraphCLSemisupervisedTrainer
    from autogl.datasets import build_dataset_from_name, utils
    from autogl.datasets.utils.conversion import to_pyg_dataset as convert_dataset
    
  • Second, set up the hyper-parameters of the encoder, the decoder and the classifier

    trainer_hp = {
      'batch_size': 128,
      'p_lr': 0.0001,                 # learning rate of the pretraining stage
      'p_weight_decay': 0,            # weight decay of the pretraining stage
      'p_epoch': 100,                 # max epochs of the pretraining stage
      'p_early_stopping_round': 100,  # early stopping round of the pretraining stage
      'f_lr': 0.0001,                 # learning rate of the fine-tuning stage
      'f_weight_decay': 0,            # weight decay of the fine-tuning stage
      'f_epoch': 100,                 # max epochs of the fine-tuning stage
      'f_early_stopping_round': 100,  # early stopping round of the fine-tuning stage
    }
    
    encoder_hp = {
      'num_layers': 3,
      'hidden': [64, 128],            # hidden dimensions; no need to set the final layer's dimension
      'dropout': 0.5,
      'act': 'relu',
      'eps': 'false'
    }
    
    decoder_hp = {
      'hidden': 64,
      'act': 'relu',
      'dropout': 0.5
    }
    
    prediction_head_hp = {
      'hidden': 64,
      'act': 'relu',
      'dropout': 0.5
    }
    
  • Third, call duplicate_from_hyper_parameter()

    dataset = build_dataset_from_name('proteins')
    dataset = convert_dataset(dataset)
    utils.graph_random_splits(dataset, train_ratio=0.1, val_ratio=0.1, seed=2022) # split the dataset
    
    # build a trainer; it cannot be used
    # before you call `duplicate_from_hyper_parameter`
    trainer = GraphCLSemisupervisedTrainer(
      model=('gcn', 'sumpoolmlp'),
      prediction_model_head='sumpoolmlp',
      views_fn=['random2', 'random2'],
      num_features=dataset[0].x.size(1),
      num_classes=max([data.y.item() for data in dataset]) + 1,
      z_dim=128,      # the embedding dimension
      init=False
    )
    
    # initialize the trainer, then call duplicate_from_hyper_parameter to set
    # the model architecture and the learning hyper-parameters
    trainer.initialize()
    trainer = trainer.duplicate_from_hyper_parameter(
      {
        'trainer': trainer_hp,
        'encoder': encoder_hp,
        'decoder': decoder_hp,
        'prediction_head': prediction_head_hp
      }
    )
    

Train and Predict

After initializing a trainer, you can train it on the given datasets.

We provide training and testing functions for graph classification tasks. You can also create your own tasks following similar patterns to ours.

We provide some interfaces, and you can easily use them to train or test on the given datasets.

  • Training: train()

    trainer.train(dataset, keep_valid_result=False)
    

    train() trains on the given dataset and optionally keeps the validation results.

    It has two parameters. The first is dataset, the graph dataset to train on. The second is keep_valid_result, a bool; if True, the trainer saves the validation results after training, provided the dataset has a validation set.

  • Testing: predict()

    trainer.predict(dataset, 'test').detach().cpu().numpy()
    

    predict() makes predictions on the given dataset.

    It has two parameters. The first is dataset, the graph dataset to predict on. The second is mask, a string which can be 'train', 'val' or 'test'. It returns the prediction results.

  • Evaluation: evaluate()

    result = trainer.evaluate(dataset, 'test')    # return a list of metrics, the default metric is accuracy
    

    evaluate() evaluates the model on the given dataset.

    It has three parameters. The first is dataset, the graph dataset to evaluate on. The second is mask, a string which can be 'train', 'val' or 'test'. The last is feval, the evaluation method(s) to use, such as accuracy; it can be a string, a tuple of strings, or None.

    You can also write your own evaluation metrics and methods. Here is a simple example:

    from autogl.module.train.evaluation import Evaluation, register_evaluate
    from sklearn.metrics import accuracy_score
    import numpy as np
    
    @register_evaluate("my_acc") # use method register_evaluate, and then you can use this class by its register name 'my_acc'
    class MyAcc(Evaluation):
      @staticmethod
      def get_eval_name():
        '''
        define the name; it does not need to be the same as the registered name
        '''
        return "my_acc"
    
      @staticmethod
      def is_higher_better():
        '''
        return whether a higher value of this metric is better (bool)
        '''
        return True
    
      @staticmethod
      def evaluate(predict, label):
        '''
        return the evaluation result (float)
        '''
        if len(predict.shape) == 2:
          predict = np.argmax(predict, axis=1)
        else:
          predict = [1 if p > 0.5 else 0 for p in predict]
        return accuracy_score(label, predict)
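
    Since feval accepts a string (or a tuple of strings), the registered metric should then be selectable by the name used in register_evaluate:

    result = trainer.evaluate(dataset, 'test', feval='my_acc')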
    

Implement SSL Trainer

Next, we will show how to implement your own SSL trainer. Implementing a trainer is harder than using one; you need to implement three main functions: _train_only(), _predict_only() and duplicate_from_hyper_parameter(). Below we implement GraphCL with an unsupervised downstream task step by step.

  • initialize your trainer

    First, we need to import some classes and methods, define a basic __init__() method, and register our trainer.

    import torch
    import torch.multiprocessing as mp
    from copy import deepcopy
    from typing import Tuple
    from torch.optim.lr_scheduler import StepLR
    from autogl.module.train import register_trainer
    from autogl.module.train.ssl.base import BaseContrastiveTrainer
    from autogl.datasets import utils
    # the exact import paths below may differ between AutoGL versions
    from autogl.module.model import BaseAutoModel
    from autogl.module.model.encoders import BaseEncoderMaintainer
    from autogl.module.model.decoders import BaseDecoderMaintainer
    
    @register_trainer("GraphCLUnsupervisedTrainer")
    class GraphCLUnsupervisedTrainer(BaseContrastiveTrainer):
      def __init__(
        self,
        model,
        prediction_model_head,
        num_features,
        num_classes,
        num_graph_features,
        device,
        feval,
        views_fn,
        aug_ratio,
        z_dim,
        num_workers,
        batch_size,
        eval_interval,
        init,
        *args,
        **kwargs,
      ):
        # setup encoder and decoder
        if isinstance(model, Tuple):
          encoder, decoder = model
        elif isinstance(model, BaseAutoModel):
          raise ValueError("The GraphCL trainer must need an encoder and a decoder, so `model` shouldn't be an instance of `BaseAutoModel`")
        else:
          encoder, decoder = model, "sumpoolmlp"
        self.eval_interval = eval_interval
        # init contrastive learning
        super().__init__(
          encoder=encoder,
          decoder=decoder,
          decoder_node=None,
          num_features=num_features,
          num_graph_features=num_graph_features,
          views_fn=views_fn,
          graph_level=True,   # have graph-level features
          node_level=False,   # have node-level features
          device=device,
          feval=feval,
          z_dim=z_dim,        # the dimension of the embedding output by the encoder
          z_node_dim=None,
          *args,
          **kwargs,
        )
        # initialize something specific for your own method
        self.views_fn = views_fn
        self.aug_ratio = aug_ratio
        self._prediction_model_head = None
        self.num_classes = num_classes
        self.prediction_model_head = prediction_model_head
        self.batch_size = batch_size
        self.num_workers = num_workers
        if self.num_workers > 0:
          mp.set_start_method("fork", force=True)
        # setup the hyperparameter when initializing
        self.hyper_parameters = {
          "batch_size": self.batch_size,
          "p_epoch": self.p_epoch,
          "p_early_stopping_round": self.p_early_stopping_round,
          "p_lr": self.p_lr,
          "p_weight_decay": self.p_weight_decay,
          "f_epoch": self.f_epoch,
          "f_early_stopping_round": self.f_early_stopping_round,
          "f_lr": self.f_lr,
          "f_weight_decay": self.f_weight_decay,
        }
        self.args = args
        self.kwargs = kwargs
        if init:
          self.initialize()
    
  • _train_only(self, dataset)

    In this method, the trainer trains the model on the given dataset. You can define several different methods for different training stages.

    • set the model on the specified device

      def _set_model_device(self, dataset):
        self.encoder.encoder.to(self.device)
        self.decoder.decoder.to(self.device)
      
    • For training, you can simply call super()._train_pretraining_only(dataset, per_epoch=True) to train the encoder. The sketch below wraps this call in _train_only(), together with the device setup above.

      def _train_only(self, dataset):
        # move the encoder and decoder onto the target device first
        self._set_model_device(dataset)
        for i, epoch in enumerate(super()._train_pretraining_only(dataset, per_epoch=True)):
          # you can define your own training process if you want
          # for example, we fine-tune the prediction head every `eval_interval` epochs
          if (i + 1) % self.eval_interval == 0:
            # fine-tuning: get the data loaders
            train_loader = utils.graph_get_split(dataset, "train", batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
            val_loader = utils.graph_get_split(dataset, "val", batch_size=self.batch_size, num_workers=self.num_workers)
            # set the encoder to eval mode and initialize the prediction head on top of it
            self.encoder.encoder.eval()
            self.prediction_model_head.initialize(self.encoder)
            # only the prediction head is fine-tuned
            model = self.prediction_model_head.decoder
            # setup the optimizer and scheduler
            optimizer = self.f_optimizer(model.parameters(), lr=self.f_lr, weight_decay=self.f_weight_decay)
            scheduler = self._get_scheduler('finetune', optimizer)
            for f_epoch in range(self.f_epoch):
              model.train()
              for data in train_loader:
                optimizer.zero_grad()
                data = data.to(self.device)
                embeds = self.encoder.encoder(data)
                out = model(embeds, data)
                loss = self.f_loss(out, data.y)
                loss.backward()
                optimizer.step()
                if self.f_lr_scheduler_type:
                  scheduler.step()
      
    • To complete the trainer, we also need to implement the _predict_only() function, which is used to evaluate the model.

      def _predict_only(self, loader, return_label=False):
        model = self._compose_model()
        model.eval()
        pred = []
        label = []
        for data in loader:
          data = data.to(self.device)
          out = model(data)
          pred.append(out)
          label.append(data.y)
        ret = torch.cat(pred, 0)
        label = torch.cat(label, 0)
        if return_label:
          return ret, label
        else:
          return ret
      
    • duplicate_from_hyper_parameter() re-generates the trainer from a given set of hyper-parameters. However, if you do not want a solver to search for good hyper-parameters automatically, you do not actually need to implement it.

      def duplicate_from_hyper_parameter(self, hp, encoder="same", decoder="same", prediction_head="same", restricted=True):
          hp_trainer = hp.get("trainer", {})
          hp_encoder = hp.get("encoder", {})
          hp_decoder = hp.get("decoder", {})
          hp_phead = hp.get("prediction_head", {})
          if not restricted:
            origin_hp = deepcopy(self.hyper_parameters)
            origin_hp.update(hp_trainer)
            hp = origin_hp
          else:
            hp = hp_trainer
          encoder = encoder if encoder != "same" else self.encoder
          decoder = decoder if decoder != "same" else self.decoder
          prediction_head = prediction_head if prediction_head != "same" else self.prediction_model_head
          encoder = encoder.from_hyper_parameter(hp_encoder)
          decoder.output_dimension = tuple(encoder.get_output_dimensions())[-1]
          if isinstance(encoder, BaseEncoderMaintainer) and isinstance(decoder, BaseDecoderMaintainer):
            decoder = decoder.from_hyper_parameter_and_encoder(hp_decoder, encoder)
          if isinstance(encoder, BaseEncoderMaintainer) and isinstance(prediction_head, BaseDecoderMaintainer):
            prediction_head = prediction_head.from_hyper_parameter_and_encoder(hp_phead, encoder)
          ret = self.__class__(
            model=(encoder, decoder),
            prediction_model_head=prediction_head,
            num_features=self.num_features,
            num_classes=self.num_classes,
            num_graph_features=self.num_graph_features,
            device=self.device,
            feval=self.feval,
            loss=self.loss,
            f_loss=self.f_loss,
            views_fn=self.views_fn_opt,
            aug_ratio=self.aug_ratio,
            z_dim=self.last_dim,
            neg_by_crpt=self.neg_by_crpt,
            tau=self.tau,
            model_path=self.model_path,
            num_workers=self.num_workers,
            batch_size=hp["batch_size"],
            eval_interval=self.eval_interval,
            p_optim=self.p_opt_received,
            p_lr=hp["p_lr"],
            p_lr_scheduler_type=self.p_lr_scheduler_type,
            p_epoch=hp["p_epoch"],
            p_early_stopping_round=hp["p_early_stopping_round"],
            p_weight_decay=hp["p_weight_decay"],
            f_optim=self.f_opt_received,
            f_lr=hp["f_lr"],
            f_lr_scheduler_type=self.f_lr_scheduler_type,
            f_epoch=hp["f_epoch"],
            f_early_stopping_round=hp["f_early_stopping_round"],
            f_weight_decay=hp["f_weight_decay"],
            init=True,
            *self.args,
            **self.kwargs
          )
      
          return ret
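
With these functions in place, the new trainer can be used with the same lazy-initialization pattern as GraphCLSemisupervisedTrainer, and since it is registered via register_trainer("GraphCLUnsupervisedTrainer"), it can also be referred to by that name when configuring a solver, which will call initialize() and duplicate_from_hyper_parameter() automatically during hyper-parameter search (see the Lazy Initialization section). The sketch below mirrors the earlier example and reuses its dataset and hyper-parameter dictionaries; the concrete argument values (num_graph_features, device, feval, aug_ratio, eval_interval, and so on) are illustrative assumptions, and any settings not passed explicitly are assumed to take their defaults from the base contrastive trainer.

  # hedged usage sketch for the custom trainer defined above
  trainer = GraphCLUnsupervisedTrainer(
    model=('gcn', 'sumpoolmlp'),
    prediction_model_head='sumpoolmlp',
    num_features=dataset[0].x.size(1),
    num_classes=max([data.y.item() for data in dataset]) + 1,
    num_graph_features=0,             # assumed: no additional graph-level input features
    device='cpu',                     # illustrative; use 'cuda' if available
    feval='acc',                      # assumed: accuracy as the evaluation metric
    views_fn=['random2', 'random2'],
    aug_ratio=0.2,                    # assumed augmentation strength for the view functions
    z_dim=128,
    num_workers=0,
    batch_size=128,
    eval_interval=10,                 # fine-tune the prediction head every 10 pretraining epochs
    init=False,
  )
  trainer.initialize()
  trainer = trainer.duplicate_from_hyper_parameter({
    'trainer': trainer_hp,
    'encoder': encoder_hp,
    'decoder': decoder_hp,
    'prediction_head': prediction_head_hp,
  })
  trainer.train(dataset, keep_valid_result=False)
  print(trainer.evaluate(dataset, 'test'))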