Source code for autogllight.nas.algorithm.autogt

# "AutoGT: Automated Graph Transformer Architecture Search" ICLR 23'


import json
import os
import random

import torch
import torch.nn.functional as F
from tqdm import tqdm, trange

from ..estimator.base import BaseEstimator
from ..space import BaseSpace
from .base import BaseNAS


class Autogt(BaseNAS):
    """
    AutoGT trainer.

    Parameters
    ----------
    num_epochs : int
        Number of epochs planned for training.
    device : str or torch.device
        Device on which the search runs.
    disable_progress : bool
        Whether to disable progress bars.
    args : argparse.Namespace
        Hyperparameters for supernet training and the evolutionary search.
    """

    def __init__(
        self,
        num_epochs=100,
        device="auto",
        disable_progress=False,
        args=None,
    ):
        super().__init__(device=device)
        self.device = device
        self.num_epochs = num_epochs
        self.disable_progress = disable_progress
        self.args = args

    def prepare(self, data):
        # Each loader is a list whose first element is the dataset size and
        # whose remaining elements are the batches.
        self.train_loader = data[0]
        self.valid_loader = data[1]
        self.test_loader = data[2]

    def train_supernet(self, optimizer, scheduler):
        # One epoch of weight-sharing supernet training: a fresh architecture
        # is sampled from the space for every batch.
        self.space.train()
        total_loss = 0
        for batched_data in self.train_loader[1:]:
            optimizer.zero_grad()
            y_hat = self.space(batched_data, self.space.model.gen_params()).squeeze()
            y_gt = batched_data.y.view(-1)
            loss = F.binary_cross_entropy_with_logits(y_hat, y_gt.float())
            loss.backward()
            optimizer.step()
            total_loss += loss.double() * batched_data.y.shape[0]
        scheduler.step()
        return total_loss / self.train_loader[0]

    def train(self, optimizer, scheduler, params=None):
        # One epoch of training with a fixed architecture encoding `params`.
        self.space.train()
        total_loss = 0
        for batched_data in self.train_loader[1:]:
            optimizer.zero_grad()
            y_hat = self.space(batched_data, params).squeeze()
            y_gt = batched_data.y.view(-1)
            loss = F.binary_cross_entropy_with_logits(y_hat, y_gt.float())
            loss.backward()
            optimizer.step()
            total_loss += loss.double() * batched_data.y.shape[0]
        scheduler.step()
        return total_loss / self.train_loader[0]

    def test(self, model, data_loader, params=None):
        model.eval()
        total_correct = 0
        total_loss = 0
        with torch.no_grad():  # evaluation needs no gradients
            for batched_data in data_loader[1:]:
                out = model(batched_data, params).squeeze()
                # `out` holds logits; predictions are thresholded at 0.5.
                total_correct += int(((out > 0.5) == batched_data.y).sum())
                loss = F.binary_cross_entropy_with_logits(out, batched_data.y.view(-1).float())
                total_loss += loss.double() * batched_data.y.shape[0]
        return total_loss / data_loader[0], total_correct / data_loader[0]

    # def _infer(self, mask="train"):
    #     if mask == "train":
    #         mask = self.train_idx
    #     elif mask == "valid":
    #         mask = self.valid_idx
    #     else:
    #         mask = self.test_idx
    #     metric, loss = self.estimator.infer(self.space, self.data, self.args.arch, mask=mask)
    #     return metric, loss

    def gen_params(self, path, spa, edg, pma, cen):
        # Sample a random architecture from the JSON search-space spec at
        # `path`. Only `spa` and `edg` are honored; `pma` and `cen` are
        # re-sampled per layer, matching the two-bit split in `get_ord`.
        with open(path, 'r') as f:
            dic = json.load(f)
        depth = random.choice(dic['depth'])
        layers = []
        for _ in range(depth):
            layer = []
            hidden_in = random.choice(dic['hidden_in'])
            num_heads = random.choice(dic['num_heads'])
            att_size = random.choice(dic['att_size'])
            hidden_mid = random.choice(dic['hidden_mid'])
            ffn_size = random.choice(dic['ffn_size'])
            mask = random.choice(dic['mask'])
            layer.append((hidden_in, num_heads, att_size, hidden_mid, ffn_size, mask))
            cen = random.choice([True, False])
            eig = random.choice([True, False])
            svd = random.choice([True, False])
            layer.append((cen, eig, svd))
            spa = spa > 0
            edg = edg > 0
            pma = random.choice([True, False])
            layer.append((spa, edg, pma))
            layers.append(tuple(layer))
        return (depth, tuple(layers))

    def get_directory(self):
        directory = f'./PROTEINS/checkpoints/{self.args.dataset_name}_4/{str(self.args.seed)}/{str(self.args.data_split)}/'
        return directory

    def get_ord(self, params):
        # Map the graph-level choices of the first layer to the index of the
        # sub-supernet that handles them; only `spa` and `edg` are used.
        spa = int(params[1][0][2][0])
        edg = int(params[1][0][2][1])
        pma = int(params[1][0][2][2])
        cen = int(params[1][0][1][0])
        ord = spa + (edg << 1)  # + (pma << 2) + (cen << 3)
        return ord

    def gen_layer(self, dic, params):
        # Sample a single new layer, inheriting the graph-level `spa`/`edg`
        # choices from `params` so the layer stays within the same subnet.
        layer = []
        hidden_in = random.choice(dic['hidden_in'])
        num_heads = random.choice(dic['num_heads'])
        att_size = random.choice(dic['att_size'])
        hidden_mid = random.choice(dic['hidden_mid'])
        ffn_size = random.choice(dic['ffn_size'])
        mask = random.choice(dic['mask'])
        layer.append([hidden_in, num_heads, att_size, hidden_mid, ffn_size, mask])
        cen = random.choice([True, False])
        eig = random.choice([True, False])
        svd = random.choice([True, False])
        layer.append([cen, eig, svd])
        spa = params[1][0][2][0]
        edg = params[1][0][2][1]
        pma = random.choice([True, False])
        layer.append([spa, edg, pma])
        return layer

    def evolution(self, directory):
        with open(self.args.path, 'r') as f:
            dic = json.load(f)
        models = []
        for ord in range(4):
            sub_name = 'supernet_' + str(ord) + '.pt'
            models.append(self.space.load_model(directory + sub_name)[0])
        information = {}
        population = []
        candidates = []

        def is_legal(params):
            # Reject duplicates; otherwise evaluate the candidate on the
            # matching sub-supernet and cache its validation/test accuracy.
            if params in information:
                return False
            info = {}
            ord = self.get_ord(params)
            _, valid_acc = self.test(models[ord], self.valid_loader, params)
            _, test_acc_ = self.test(models[ord], self.test_loader, params)
            info['valid_acc'] = valid_acc
            info['test_acc_'] = test_acc_
            information[params] = info
            return True

        def get_mutation():
            result = []

            def random_function():
                # Mutate a random member of the population: possibly change
                # the depth, then flip each gene with probability `m_prob`.
                params = random.choice(population)
                depth, layers = params
                layers = [list(list(item) for item in layer) for layer in layers]
                if random.random() < self.args.m_prob:
                    new_depth = random.choice(dic['depth'])
                    if new_depth > depth:
                        layers = layers + [self.gen_layer(dic, params) for _ in range(new_depth - depth)]
                    else:
                        layers = layers[:new_depth]
                    depth = new_depth
                for i in range(depth):
                    if random.random() < self.args.m_prob:
                        layers[i][0][0] = random.choice(dic['hidden_in'])
                    if random.random() < self.args.m_prob:
                        layers[i][0][1] = random.choice(dic['num_heads'])
                    if random.random() < self.args.m_prob:
                        layers[i][0][2] = random.choice(dic['att_size'])
                    if random.random() < self.args.m_prob:
                        layers[i][0][3] = random.choice(dic['hidden_mid'])
                    if random.random() < self.args.m_prob:
                        layers[i][0][4] = random.choice(dic['ffn_size'])
                    if random.random() < self.args.m_prob:
                        layers[i][0][5] = random.choice(dic['mask'])
                    if random.random() < self.args.m_prob:
                        layers[i][1][0] = random.choice([True, False])
                    if random.random() < self.args.m_prob:
                        layers[i][1][1] = random.choice([True, False])
                    if random.random() < self.args.m_prob:
                        layers[i][1][2] = random.choice([True, False])
                # The graph-level choices are shared, so each one mutates
                # jointly across all layers.
                if random.random() < self.args.m_prob:
                    flag = random.choice([True, False])
                    for i in range(depth):
                        layers[i][2][0] = flag
                if random.random() < self.args.m_prob:
                    flag = random.choice([True, False])
                    for i in range(depth):
                        layers[i][2][1] = flag
                if random.random() < self.args.m_prob:
                    flag = random.choice([True, False])
                    for i in range(depth):
                        layers[i][2][2] = flag
                layers = tuple(tuple(tuple(item) for item in layer) for layer in layers)
                return (depth, layers)

            iters = self.args.mutation_num * 10
            while len(result) < self.args.mutation_num and iters > 0:
                iters -= 1
                params = random_function()
                if not is_legal(params):
                    continue
                result.append(params)
            return result

        def get_hybridization():
            result = []

            def random_function():
                # Cross over two parents of equal depth, picking each gene
                # from either parent uniformly at random.
                params1 = random.choice(population)
                params2 = random.choice(population)
                iters = self.args.population_num
                while params1[0] != params2[0] and iters > 0:
                    iters -= 1
                    params1 = random.choice(population)
                    params2 = random.choice(population)
                AT_choice = random.choice([params1[1][0][2], params2[1][0][2]])
                layers = []
                for i in range(params1[0]):
                    shape_choice = []
                    for j in range(6):
                        shape_choice.append(random.choice([params1[1][i][0][j], params2[1][i][0][j]]))
                    PE_choice = []
                    for j in range(3):
                        PE_choice.append(random.choice([params1[1][i][1][j], params2[1][i][1][j]]))
                    layers.append((tuple(shape_choice), tuple(PE_choice), AT_choice))
                return (params1[0], tuple(layers))

            iters = 10 * self.args.hybridization_num
            while len(result) < self.args.hybridization_num and iters > 0:
                iters -= 1
                params = random_function()
                if not is_legal(params):
                    continue
                result.append(params)
            return result

        with trange(self.args.evol_epochs) as bar:
            for epoch in bar:
                # Top up the candidate pool with random architectures, keep
                # the best `population_num` by validation accuracy, then
                # breed the next generation.
                while len(candidates) < self.args.population_num:
                    spa = random.choice([True, False])
                    edg = random.choice([True, False])
                    pma = random.choice([True, False])
                    cen = random.choice([True, False])
                    params = self.gen_params(self.args.path, spa, edg, pma, cen)
                    if not is_legal(params):
                        continue
                    candidates.append(params)
                population += candidates
                population.sort(key=lambda x: information[x]['valid_acc'], reverse=True)
                population = population[:self.args.population_num]
                best_valid = information[population[0]]['valid_acc']
                best_test = information[population[0]]['test_acc_']
                bar.set_postfix(epoch=epoch, val_acc=best_valid, test_acc=best_test)
                candidates = get_mutation() + get_hybridization()

    def fit(self):
        optimizer, lr_scheduler = self.space.model.configure_optimizers()
        scheduler = lr_scheduler['scheduler']

        print('=> Phase 1: train supernet')
        for epoch in tqdm(range(self.args.split_epochs)):
            self.train_supernet(optimizer, scheduler)
        directory = self.get_directory()
        if not os.path.exists(directory):
            os.makedirs(directory)
        name = 'supernet.pt'
        self.space.save_model(optimizer, scheduler, directory + name)

        print('=> Phase 2: train subnets')
        for ord in range(4):
            # Decode the subnet index into its fixed graph-level choices;
            # only the two low bits (spa, edg) vary across the four subnets.
            spa = int((ord & 1) != 0)
            edg = int((ord & 2) != 0)
            pma = int((ord & 4) != 0)
            cen = int((ord & 8) != 0)
            _, optimizer, scheduler = self.space.load_model(directory + name)
            for epoch in tqdm(range(self.args.split_epochs, self.args.end_epochs)):
                params = self.gen_params(self.args.path, spa, edg, pma, cen)
                self.train(optimizer, scheduler, params)
            sub_name = 'supernet_' + str(ord) + '.pt'
            self.space.save_model(optimizer, scheduler, directory + sub_name)

        print('=> Phase 3: evolution')
        self.evolution(directory)
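    # Example of the architecture encoding handled above (all concrete
    # values below are hypothetical):
    #
    #   params = (2, (                        # depth = 2
    #       ((64, 8, 32, 64, 64, 1),          # shapes: hidden_in, num_heads, att_size, hidden_mid, ffn_size, mask
    #        (True, False, True),             # positional encodings: cen, eig, svd
    #        (True, False, False)),           # graph-level: spa, edg, pma (shared by all layers)
    #       (...,),                           # second layer, same structure
    #   ))
    #
    #   # get_ord(params) -> int(spa) + (int(edg) << 1) = 1 + 0 = 1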
    def search(self, space: BaseSpace, data, estimator: BaseEstimator):
        self.estimator = estimator
        self.space = space.to(self.device)
        self.prepare(data)
        self.fit()
        # return space.parse_model(None)
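

# A minimal usage sketch, assuming `space` implements `BaseSpace` with the
# `model.configure_optimizers()`, `save_model`, and `load_model` hooks used
# above, each loader follows the `[dataset_size, batch, batch, ...]`
# convention, and `args` carries the hyperparameters read in `fit` and
# `evolution`. Every concrete name and value below is hypothetical.

from types import SimpleNamespace

# Hypothetical hyperparameters; `path` points to a JSON search-space spec
# with keys 'depth', 'hidden_in', 'num_heads', 'att_size', 'hidden_mid',
# 'ffn_size', and 'mask'.
args = SimpleNamespace(
    path="search_space.json", dataset_name="PROTEINS", seed=0, data_split=0,
    split_epochs=50, end_epochs=100, evol_epochs=30,
    population_num=50, mutation_num=25, hybridization_num=25, m_prob=0.2,
)

algo = Autogt(device="cuda", args=args)
# `space`, `estimator`, and the three loaders are assumed to be built
# elsewhere; search() trains the supernets, then evolves architectures.
algo.search(space, (train_loader, valid_loader, test_loader), estimator)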