AutoGL SSL Trainer
The AutoGL project uses trainers to implement graph self-supervised learning methods. Currently, we only support GraphCL with semi-supervised downstream tasks:
GraphCLSemisupervisedTrainer
Uses the GraphCL algorithm for semi-supervised downstream tasks. The main interfaces are shown below; a short usage sketch follows the list.

train(self, dataset, keep_valid_result=True)
    Train on the given dataset and optionally keep the validation results.
    dataset: the graph dataset to train on.
    keep_valid_result: if True, save the validation result after training. Only if keep_valid_result is True and training has finished will get_valid_score, get_valid_predict_proba and get_valid_predict output meaningful results.

predict(self, dataset, mask="test")
    Predict on the given dataset.
    dataset: the graph dataset to predict on.
    mask: "train", "val" or "test", the dataset mask.

predict_proba(self, dataset, mask="test", in_log_format=False)
    Predict the class probabilities on the given dataset.
    dataset: the graph dataset to predict on.
    mask: "train", "val" or "test", the dataset mask.
    in_log_format: if True, the probabilities are returned in log format.

evaluate(self, dataset, mask="val", feval=None)
    Evaluate the model on the given dataset.
    dataset: the graph dataset to evaluate on.
    mask: "train", "val" or "test", the dataset mask.
    feval: the evaluation method used in this function. If feval is None, the feval given at initialization is used.

get_valid_score(self, return_major=True)
    Get the validation scores after training.
    return_major: if True, only the major result is returned.

get_valid_predict_proba(self)
    Get the prediction probabilities on the validation set after training.

get_valid_predict(self)
    Get the validation predictions after training.
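Below is a minimal usage sketch of these interfaces. It assumes trainer is an already initialized GraphCLSemisupervisedTrainer and dataset is a graph dataset with train/val/test splits; the concrete setup is shown in the following sections.

trainer.train(dataset, keep_valid_result=True)       # train and keep the validation results
score = trainer.get_valid_score()                     # meaningful because keep_valid_result=True
proba = trainer.predict_proba(dataset, mask="test")   # class probabilities on the test split
pred = trainer.predict(dataset, mask="test")          # predicted labels on the test split
metrics = trainer.evaluate(dataset, mask="test")      # evaluate with the default feval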
Lazy Initialization
For the same reason as :ref:`model`, we also use lazy initialization for all trainers. Only (part of) the hyper-parameters are set when __init__() is called. The trainer obtains its core model only after initialize() is explicitly called, which is done automatically in the solver and in duplicate_from_hyper_parameter(), after all the hyper-parameters are set properly.
For example, if you want to use gcn as the encoder, a simple mlp as the decoder, and another mlp as the classifier to solve a graph classification problem, there are three steps to follow.
First, import everything you need:

from autogl.module.train.ssl import GraphCLSemisupervisedTrainer
from autogl.datasets import build_dataset_from_name, utils
from autogl.datasets.utils.conversion import to_pyg_dataset as convert_dataset
Second, set up the hyper-parameters of the trainer, the encoder, the decoder and the classifier:
trainer_hp = {
    'batch_size': 128,
    'p_lr': 0.0001,                 # learning rate of the pretraining stage
    'p_weight_decay': 0,            # weight decay of the pretraining stage
    'p_epoch': 100,                 # max epoch of the pretraining stage
    'p_early_stopping_round': 100,  # early stopping round of the pretraining stage
    'f_lr': 0.0001,                 # learning rate of the fine-tuning stage
    'f_weight_decay': 0,            # weight decay of the fine-tuning stage
    'f_epoch': 100,                 # max epoch of the fine-tuning stage
    'f_early_stopping_round': 100,  # early stopping round of the fine-tuning stage
}

encoder_hp = {
    'num_layers': 3,
    'hidden': [64, 128],  # hidden dimensions; the final layer dimension does not need to be set
    'dropout': 0.5,
    'act': 'relu',
    'eps': 'false'
}

decoder_hp = {
    'hidden': 64,
    'act': 'relu',
    'dropout': 0.5
}

prediction_head_hp = {
    'hidden': 64,
    'act': 'relu',
    'dropout': 0.5
}
Third, use duplicate_from_hyper_parameter() to build the fully initialized trainer:
dataset = build_dataset_from_name('proteins')
dataset = convert_dataset(dataset)
# split the dataset
utils.graph_random_splits(dataset, train_ratio=0.1, val_ratio=0.1, seed=2022)

# generate a trainer; it cannot be used
# before you call `duplicate_from_hyper_parameter`
trainer = GraphCLSemisupervisedTrainer(
    model=('gcn', 'sumpoolmlp'),
    prediction_model_head='sumpoolmlp',
    views_fn=['random2', 'random2'],
    num_features=dataset[0].x.size(1),
    num_classes=max([data.y.item() for data in dataset]) + 1,
    z_dim=128,  # the embedding dimension
    init=False
)

# call duplicate_from_hyper_parameter to set the model architecture
# and the learning hyper-parameters
trainer.initialize()
trainer = trainer.duplicate_from_hyper_parameter(
    {
        'trainer': trainer_hp,
        'encoder': encoder_hp,
        'decoder': decoder_hp,
        'prediction_head': prediction_head_hp
    }
)
Train and Predict
After initializing a trainer, you can train it on the given datasets.
We provide training and testing functions for graph classification tasks, and you can also create your own tasks following similar patterns. The interfaces below make it easy to train, predict and evaluate on a given dataset.
Training:
train()
trainer.train(dataset, keep_valid_result=False)
train() trains the model on the given dataset and keeps the validation results. It has two parameters: the first is dataset, the graph dataset to train on; the second is keep_valid_result, a bool value. If it is True, the trainer saves the validation result after training, provided the dataset has a validation set.

Testing:
predict()
trainer.predict(dataset, 'test').detach().cpu().numpy()
predict() predicts on the given dataset and returns the prediction results. It has two parameters: the first is dataset, the graph dataset to predict on; the second is mask, a string which can be 'train', 'val' or 'test'.

Evaluation:
evaluate()
result = trainer.evaluate(dataset, 'test') # return a list of metrics, the default metric is accuracy
evaluate() evaluates the model on the given dataset. It has three parameters: the first is dataset, the graph dataset to evaluate on; the second is mask, a string which can be 'train', 'val' or 'test'; the last is feval, which can be a string, a tuple of strings, or None and specifies the evaluation methods to use, such as Acc.

You can also write your own evaluation metrics and methods. Here is a simple example:
import numpy as np
from sklearn.metrics import accuracy_score
from autogl.module.train.evaluation import Evaluation, register_evaluate

# register the metric; afterwards you can refer to this class by its registered name 'my_acc'
@register_evaluate("my_acc")
class MyAcc(Evaluation):
    @staticmethod
    def get_eval_name():
        ''' define the name; it does not need to match the registered name '''
        return "my_acc"

    @staticmethod
    def is_higher_better():
        ''' return whether a higher value of this metric is better (bool) '''
        return True

    @staticmethod
    def evaluate(predict, label):
        ''' return the evaluation result (float) '''
        if len(predict.shape) == 2:
            predict = np.argmax(predict, axis=1)
        else:
            predict = [1 if p > 0.5 else 0 for p in predict]
        return accuracy_score(label, predict)
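Once registered, the metric can be referred to by its registered name wherever a feval is expected. A minimal sketch, assuming the trainer and dataset built in the earlier example:

result = trainer.evaluate(dataset, 'test', feval='my_acc')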
Implement SSL Trainer
Next, we will show how to implement your own SSL trainer. Implementing a trainer is more difficult than using one: you need to implement three main functions, _train_only(), _predict_only() and duplicate_from_hyper_parameter(). Now we will implement GraphCL with unsupervised downstream tasks step by step.
Initialize your trainer
First, we need to import some classes and methods, define a basic __init__() method, and register our trainer.

import torch
import torch.multiprocessing as mp
from copy import deepcopy
from typing import Tuple
from torch.optim.lr_scheduler import StepLR
from autogl.module.train import register_trainer
from autogl.module.train.ssl.base import BaseContrastiveTrainer
# base classes referenced below; the exact import path may differ across AutoGL versions
from autogl.module.model import BaseAutoModel, BaseEncoderMaintainer, BaseDecoderMaintainer
from autogl.datasets import utils

@register_trainer("GraphCLUnsupervisedTrainer")
class GraphCLUnsupervisedTrainer(BaseContrastiveTrainer):
    def __init__(
        self,
        model,
        prediction_model_head,
        num_features,
        num_classes,
        num_graph_features,
        device,
        feval,
        views_fn,
        aug_ratio,
        z_dim,
        num_workers,
        batch_size,
        eval_interval,
        init,
        *args,
        **kwargs,
    ):
        # setup encoder and decoder
        if isinstance(model, Tuple):
            encoder, decoder = model
        elif isinstance(model, BaseAutoModel):
            raise ValueError(
                "The GraphCL trainer needs an encoder and a decoder, "
                "so `model` shouldn't be an instance of `BaseAutoModel`"
            )
        else:
            encoder, decoder = model, "sumpoolmlp"
        self.eval_interval = eval_interval
        # init contrastive learning
        super().__init__(
            encoder=encoder,
            decoder=decoder,
            decoder_node=None,
            num_features=num_features,
            num_graph_features=num_graph_features,
            views_fn=views_fn,
            graph_level=True,   # have graph-level features
            node_level=False,   # have node-level features
            device=device,
            feval=feval,
            z_dim=z_dim,        # the dimension of the embedding output by the encoder
            z_node_dim=None,
            *args,
            **kwargs,
        )
        # initialize something specific for your own method
        self.views_fn = views_fn
        self.aug_ratio = aug_ratio
        self._prediction_model_head = None
        self.num_classes = num_classes
        self.prediction_model_head = prediction_model_head
        self.batch_size = batch_size
        self.num_workers = num_workers
        if self.num_workers > 0:
            mp.set_start_method("fork", force=True)
        # setup the hyper-parameters when initializing
        self.hyper_parameters = {
            "batch_size": self.batch_size,
            "p_epoch": self.p_epoch,
            "p_early_stopping_round": self.p_early_stopping_round,
            "p_lr": self.p_lr,
            "p_weight_decay": self.p_weight_decay,
            "f_epoch": self.f_epoch,
            "f_early_stopping_round": self.f_early_stopping_round,
            "f_lr": self.f_lr,
            "f_weight_decay": self.f_weight_decay,
        }
        self.args = args
        self.kwargs = kwargs
        if init:
            self.initialize()
_train_only(self, dataset)
In this method, the trainer trains the model on the given dataset. You can define several helper methods for the different training stages.

First, set the model on the specified device:
def _set_model_device(self, dataset):
    self.encoder.encoder.to(self.device)
    self.decoder.decoder.to(self.device)
For training, you can simply call super()._train_pretraining_only(dataset, per_epoch) to train the encoder.

for i, epoch in enumerate(super()._train_pretraining_only(dataset, per_epoch=True)):
    # you can define your own training process if you want;
    # for example, we fine-tune every eval_interval epochs
    if (i + 1) % self.eval_interval == 0:
        # fine-tuning
        # get dataset
        train_loader = utils.graph_get_split(
            dataset, "train", batch_size=self.batch_size,
            num_workers=self.num_workers, shuffle=True
        )
        val_loader = utils.graph_get_split(
            dataset, "val", batch_size=self.batch_size, num_workers=self.num_workers
        )
        # setup model
        self.encoder.encoder.eval()
        self.prediction_model_head.initialize(self.encoder)
        # fine-tune only the prediction head
        model = self.prediction_model_head.decoder
        # setup optimizer and scheduler
        optimizer = self.f_optimizer(model.parameters(), lr=self.f_lr, weight_decay=self.f_weight_decay)
        scheduler = self._get_scheduler('finetune', optimizer)
        for epoch in range(self.f_epoch):
            model.train()
            for data in train_loader:
                optimizer.zero_grad()
                data = data.to(self.device)
                embeds = self.encoder.encoder(data)
                out = model(embeds, data)
                loss = self.f_loss(out, data.y)
                loss.backward()
                optimizer.step()
            if self.f_lr_scheduler_type:
                scheduler.step()
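The pieces above can then be assembled into _train_only(). A minimal sketch, assuming the attributes defined in __init__ above; the helper name _finetune_only is hypothetical and simply wraps the fine-tuning loop shown above:

def _train_only(self, dataset):
    # move the encoder and decoder to the configured device first
    self._set_model_device(dataset)
    # pretrain the encoder; every eval_interval epochs, fine-tune the prediction head
    for i, epoch in enumerate(super()._train_pretraining_only(dataset, per_epoch=True)):
        if (i + 1) % self.eval_interval == 0:
            self._finetune_only(dataset)  # hypothetical helper wrapping the fine-tuning loop above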
To complete the trainer, we also need to implement the _predict_only() function to evaluate the model.

def _predict_only(self, loader, return_label=False):
    model = self._compose_model()
    model.eval()
    pred = []
    label = []
    for data in loader:
        data = data.to(self.device)
        out = model(data)
        pred.append(out)
        label.append(data.y)
    ret = torch.cat(pred, 0)
    label = torch.cat(label, 0)
    if return_label:
        return ret, label
    else:
        return ret
duplicate_from_hyper_parameter
duplicate_from_hyper_parameter() re-generates the trainer from a given set of hyper-parameters. If you don't want to use a solver to search for good hyper-parameters automatically, you don't actually need to implement it.

def duplicate_from_hyper_parameter(self, hp, encoder="same", decoder="same", prediction_head="same", restricted=True):
    hp_trainer = hp.get("trainer", {})
    hp_encoder = hp.get("encoder", {})
    hp_decoder = hp.get("decoder", {})
    hp_phead = hp.get("prediction_head", {})
    if not restricted:
        origin_hp = deepcopy(self.hyper_parameters)
        origin_hp.update(hp_trainer)
        hp = origin_hp
    else:
        hp = hp_trainer
    encoder = encoder if encoder != "same" else self.encoder
    decoder = decoder if decoder != "same" else self.decoder
    prediction_head = prediction_head if prediction_head != "same" else self.prediction_model_head
    encoder = encoder.from_hyper_parameter(hp_encoder)
    decoder.output_dimension = tuple(encoder.get_output_dimensions())[-1]
    if isinstance(encoder, BaseEncoderMaintainer) and isinstance(decoder, BaseDecoderMaintainer):
        decoder = decoder.from_hyper_parameter_and_encoder(hp_decoder, encoder)
    if isinstance(encoder, BaseEncoderMaintainer) and isinstance(prediction_head, BaseDecoderMaintainer):
        prediction_head = prediction_head.from_hyper_parameter_and_encoder(hp_phead, encoder)
    ret = self.__class__(
        model=(encoder, decoder),
        prediction_model_head=prediction_head,
        num_features=self.num_features,
        num_classes=self.num_classes,
        num_graph_features=self.num_graph_features,
        device=self.device,
        feval=self.feval,
        loss=self.loss,
        f_loss=self.f_loss,
        views_fn=self.views_fn_opt,
        aug_ratio=self.aug_ratio,
        z_dim=self.last_dim,
        neg_by_crpt=self.neg_by_crpt,
        tau=self.tau,
        model_path=self.model_path,
        num_workers=self.num_workers,
        batch_size=hp["batch_size"],
        eval_interval=self.eval_interval,
        p_optim=self.p_opt_received,
        p_lr=hp["p_lr"],
        p_lr_scheduler_type=self.p_lr_scheduler_type,
        p_epoch=hp["p_epoch"],
        p_early_stopping_round=hp["p_early_stopping_round"],
        p_weight_decay=hp["p_weight_decay"],
        f_optim=self.f_opt_received,
        f_lr=hp["f_lr"],
        f_lr_scheduler_type=self.f_lr_scheduler_type,
        f_epoch=hp["f_epoch"],
        f_early_stopping_round=hp["f_early_stopping_round"],
        f_weight_decay=hp["f_weight_decay"],
        init=True,
        *self.args,
        **self.kwargs
    )
    return ret
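Because the class was registered with register_trainer("GraphCLUnsupervisedTrainer"), it can be constructed and driven in the same way as the built-in trainer. A minimal sketch, assuming the proteins dataset and the trainer_hp/encoder_hp/decoder_hp/prediction_head_hp dictionaries from the earlier example; the concrete argument values here are only placeholders:

trainer = GraphCLUnsupervisedTrainer(
    model=('gcn', 'sumpoolmlp'),
    prediction_model_head='sumpoolmlp',
    num_features=dataset[0].x.size(1),
    num_classes=max([data.y.item() for data in dataset]) + 1,
    num_graph_features=0,
    device='cpu',
    feval='acc',
    views_fn=['random2', 'random2'],
    aug_ratio=0.2,
    z_dim=128,
    num_workers=0,
    batch_size=128,
    eval_interval=10,
    init=False
)
trainer = trainer.duplicate_from_hyper_parameter({
    'trainer': trainer_hp,
    'encoder': encoder_hp,
    'decoder': decoder_hp,
    'prediction_head': prediction_head_hp
})
trainer.train(dataset, keep_valid_result=False)
result = trainer.evaluate(dataset, 'test')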