# "Graph Neural Architecture Search Under Distribution Shifts" ICML 22'
import torch
from tqdm import trange
from ..estimator.base import BaseEstimator
from ..space import BaseSpace
from .base import BaseNAS
class Graces(BaseNAS):
    """
    GRACES trainer.

    Jointly trains three groups of parameters that live on the search
    space -- ``space.supernet``, ``space.ag`` and ``space.supernet0`` --
    each with its own Adam optimizer, and keeps the prediction of the
    epoch with the lowest validation loss.

    Parameters
    ----------
    num_epochs : int
        Number of epochs planned for training.
    device : str or torch.device
        The device of the whole process
    disable_progress : bool
        Forwarded to ``tqdm.trange(disable=...)``; ``True`` hides the
        progress bar.
    args : object
        Hyper-parameter container. The trainer reads the attributes
        ``learning_rate``, ``weight_decay``, ``arch_learning_rate``,
        ``arch_weight_decay``, ``arch_learning_rate_min``,
        ``gnn0_learning_rate``, ``gnn0_weight_decay``,
        ``gnn0_learning_rate_min``, ``eta``, ``eta_max``, ``gamma`` and
        ``beta``. It defaults to ``None`` but must be set before
        :meth:`fit` / :meth:`search` is called.
    """

    def __init__(
        self,
        num_epochs=250,
        device="auto",
        disable_progress=False,
        args=None,
    ):
        super().__init__(device=device)
        self.num_epochs = num_epochs
        self.disable_progress = disable_progress
        self.args = args

    def train_graph(
        self,
        model_optimizer,
        arch_optimizer,
        gnn0_optimizer,
        eta,
    ):
        """
        Run one optimisation pass over ``self.train_loader``.

        Relies on ``self.criterion``, which :meth:`fit` assigns before
        calling this method.

        Parameters
        ----------
        model_optimizer : torch.optim.Optimizer
            Optimizer stepped first after each backward pass.
        arch_optimizer : torch.optim.Optimizer
            Optimizer stepped last after each backward pass.
        gnn0_optimizer : torch.optim.Optimizer
            Optimizer stepped second after each backward pass.
        eta : float
            Mixing weight: the loss is
            ``(1 - eta) * (loss0 + gamma * ssl + beta * cos) + eta * loss``.
        """
        self.space.train()
        # The L1 criterion holds no state, so build it once per pass
        # instead of once per batch.
        ssl_criterion = torch.nn.L1Loss()

        for train_data in self.train_loader:
            model_optimizer.zero_grad()
            arch_optimizer.zero_grad()
            gnn0_optimizer.zero_grad()

            train_data = train_data.to(self.device)
            output0, output, cosloss, sslout = self.space(train_data)
            output0 = output0.to(self.device)
            output = output.to(self.device)

            # Self-comparison is False exactly where y is NaN (NaN != NaN),
            # so those entries are excluded from the supervised losses.
            # NOTE(review): presumably NaN marks a missing label -- confirm.
            is_labeled = train_data.y == train_data.y
            target = train_data.y.to(torch.float32)[is_labeled]

            error_loss0 = self.criterion(
                output0.to(torch.float32)[is_labeled], target
            )
            error_loss = self.criterion(
                output.to(torch.float32)[is_labeled], target
            )

            # Auxiliary regression target, reshaped to three columns.
            ssltarget = train_data.deratio.view(-1, 3)
            sslloss = ssl_criterion(sslout, ssltarget)

            my_loss = (1 - eta) * (
                error_loss0 + self.args.gamma * sslloss + self.args.beta * cosloss
            ) + eta * error_loss
            my_loss.backward()

            model_optimizer.step()
            gnn0_optimizer.step()
            arch_optimizer.step()

    def _infer(self, mask="train"):
        """
        Evaluate the space on one of the prepared data splits.

        Parameters
        ----------
        mask : str
            One of ``"train"``, ``"val"`` or ``"test"``.

        Returns
        -------
        tuple
            ``(metric, loss)`` exactly as returned by
            ``self.estimator.infer``.

        Raises
        ------
        ValueError
            If ``mask`` is not one of the three supported split names.
        """
        if mask == "train":
            dataloader = self.train_loader
        elif mask == "val":
            dataloader = self.val_loader
        elif mask == "test":
            dataloader = self.test_loader
        else:
            # Previously an unknown mask left `dataloader` unbound and
            # failed with an opaque UnboundLocalError.
            raise ValueError(
                f"mask must be 'train', 'val' or 'test', got {mask!r}"
            )
        metric, loss = self.estimator.infer(self.space, dataloader)
        return metric, loss

    def prepare(self, data):
        """
        Store the data loaders used for training and evaluation.

        Parameters
        ----------
        data : list of data objects.
            [dataset, train_dataset, val_dataset, test_dataset,
            train_loader, val_loader, test_loader]; only indices 4, 5
            and 6 are read here.
        """
        self.train_loader = data[4]
        self.val_loader = data[5]
        self.test_loader = data[6]

    def fit(self):
        """
        Train the space for ``self.num_epochs`` epochs.

        After every epoch the train and validation splits are evaluated;
        whenever the validation loss reaches a new minimum,
        ``self.space.keep_prediction()`` is called.

        Returns
        -------
        tuple
            ``(best_performance, min_val_loss)``: the validation ``"auc"``
            metric recorded at the epoch with the lowest validation loss,
            and that loss. ``(0, inf)`` if no epoch was run.
        """
        optimizer = torch.optim.Adam(
            self.space.supernet.parameters(),
            self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )
        arch_optimizer = torch.optim.Adam(
            self.space.ag.parameters(),
            self.args.arch_learning_rate,
            weight_decay=self.args.arch_weight_decay,
        )
        gnn0_optimizer = torch.optim.Adam(
            self.space.supernet0.parameters(),
            self.args.gnn0_learning_rate,
            weight_decay=self.args.gnn0_weight_decay,
        )
        # Only the arch and gnn0 optimizers are annealed; the supernet
        # optimizer keeps a constant learning rate.
        scheduler_arch = torch.optim.lr_scheduler.CosineAnnealingLR(
            arch_optimizer,
            float(self.num_epochs),
            eta_min=self.args.arch_learning_rate_min,
        )
        scheduler_gnn0 = torch.optim.lr_scheduler.CosineAnnealingLR(
            gnn0_optimizer,
            float(self.num_epochs),
            eta_min=self.args.gnn0_learning_rate_min,
        )
        self.criterion = torch.nn.BCEWithLogitsLoss()

        best_performance = 0
        min_val_loss = float("inf")
        with trange(self.num_epochs, disable=self.disable_progress) as bar:
            for epoch in bar:
                # --- space training ---
                # eta ramps linearly from args.eta (epoch 0) towards
                # args.eta_max.
                eta = (
                    self.args.eta_max - self.args.eta
                ) * epoch / self.num_epochs + self.args.eta
                # train_graph() switches the space to train mode and zeroes
                # all three optimizers before every batch, so neither is
                # repeated here.
                self.train_graph(
                    optimizer,
                    arch_optimizer,
                    gnn0_optimizer,
                    eta,
                )
                scheduler_arch.step()
                scheduler_gnn0.step()

                # --- space evaluation ---
                self.space.eval()
                train_metric, train_loss = self._infer("train")
                val_metric, val_loss = self._infer("val")

                # Model selection is driven by validation loss, not AUC.
                if min_val_loss > val_loss:
                    min_val_loss, best_performance = val_loss, val_metric["auc"]
                    self.space.keep_prediction()

                bar.set_postfix(
                    {
                        "train_auc": train_metric["auc"],
                        "val_auc": val_metric["auc"],
                    }
                )
        return best_performance, min_val_loss

    def search(self, space: BaseSpace, dataset, estimator: BaseEstimator):
        """
        Train ``space`` on ``dataset`` and return the parsed model.

        Parameters
        ----------
        space : BaseSpace
            Search space to train; it is moved to ``self.device``.
        dataset : list
            Data objects in the layout expected by :meth:`prepare`.
        estimator : BaseEstimator
            Estimator whose ``infer`` method evaluates the space.

        Returns
        -------
        object
            The result of ``space.parse_model(None)``.
        """
        self.estimator = estimator
        self.space = space.to(self.device)
        self.prepare(dataset)
        # The (best_performance, min_val_loss) result is not needed here.
        self.fit()
        return space.parse_model(None)