Source code for autogl.module.nas.estimator.grna_estimator

import torch
from deeprobust.graph.global_attack import Random, DICE
import numpy as np
import scipy.sparse as sp
from tqdm import tqdm
import time
import torch.nn.functional as F

from . import register_nas_estimator
from ..space import BaseSpace
from .base import BaseEstimator
from ..backend import *
from ...train.evaluation import Acc
from ..utils import get_hardware_aware_metric

@register_nas_estimator("grna")
class GRNAEstimator(BaseEstimator):
    """
    Graph robust neural architecture estimator under adversarial attack.
    Uses the model directly to get estimations.

    Parameters
    ----------
    loss_f : str
        The name of the loss function in PyTorch.
    evaluation : list of Evaluation
        The evaluation metrics in module/train/evaluation.
        GRNA_metric = acc_metric + lambda_ * robustness_metric
    lambda_ : float
        The hyper-parameter balancing the accuracy metric and the robustness
        metric in the final evaluation.
    perturb_type : str
        Perturbation method used to simulate the adversarial attack
        ('random' or 'dice').
    adv_sample_num : int
        Number of adversarial samples used to measure architecture robustness.
    dis_type : str
        Distance type between perturbed and clean predictions
        ('fro', 'ce', 'kl' or 'cw').
    ptbr : float
        Perturbation ratio, i.e. the proportion of edges to perturb.

    A usage sketch follows the class definition below.
    """

    def __init__(
        self,
        loss_f="nll_loss",
        evaluation=[Acc()],
        lambda_=0.05,
        perturb_type='random',
        adv_sample_num=10,
        dis_type='ce',
        ptbr=0.05,
    ):
        super().__init__(loss_f, evaluation)
        self.evaluation = evaluation
        self.lambda_ = lambda_
        self.perturb_type = perturb_type
        self.dis_type = dis_type
        self.adv_sample_num = adv_sample_num
        self.ptbr = ptbr
        print('initialize GRNA estimator')

    def infer(self, model: BaseSpace, dataset, mask="train"):
        device = next(model.parameters()).device
        dset = dataset[0].to(device)
        # keep a reference to the dataset; label-aware attacks (e.g. DICE) need it
        self.data = dset
        mask = bk_mask(dset, mask)

        # accuracy metric on the clean graph
        pred = model(dset)[mask]
        label = bk_label(dset)
        y = label[mask]
        loss = getattr(F, self.loss_f)(pred, y)
        probs = F.softmax(pred, dim=1).detach().cpu().numpy()

        # robustness metric: average prediction shift over adversarially
        # perturbed copies of the graph
        # (note: dis_type='cw' would additionally require passing `data` to distance)
        dist = 0
        for _ in range(self.adv_sample_num):
            modified_adj = self.gen_adversarial_samples(
                dset.edge_index,
                num_nodes=dset.num_nodes,
                perturb_prop=self.ptbr,
                attack_method=self.perturb_type,
            )
            d_data = dset.clone()
            d_data = d_data.to(device)
            coo = modified_adj.tocoo()
            edge_index = torch.LongTensor(np.vstack((coo.row, coo.col)))
            d_data.edge_index = edge_index.to(device)
            perturb_pred = model(d_data)[mask]
            dist += distance(perturb_pred, pred, dis_type=self.dis_type)
        dist = dist / self.adv_sample_num

        y = y.cpu()
        metrics = [
            eva.evaluate(probs, y) + self.lambda_ * dist for eva in self.evaluation
        ]
        return metrics, loss

    def gen_adversarial_samples(self, edge_index, num_nodes, perturb_prop, attack_method='random'):
        if num_nodes is None:
            num_nodes = int(edge_index.max()) + 1
        edge_index = edge_index.detach().cpu().numpy()
        # number of edges to flip (edge_index stores each undirected edge twice)
        delta = int(edge_index.shape[1] // 2 * perturb_prop)
        v = np.ones_like(edge_index[0])
        adj = sp.csr_matrix(
            (v, (edge_index[0], edge_index[1])), shape=(num_nodes, num_nodes)
        )
        if attack_method == 'random':
            attacker = Random()
            attacker.attack(adj, n_perturbations=delta, type='flip')
        elif attack_method == 'dice':
            labels = self.data.y.cpu().numpy()
            attacker = DICE()
            attacker.attack(adj, labels, delta)
        else:
            assert False, 'Wrong type of attack method!'
        modified_adj = attacker.modified_adj  # scipy.sparse matrix
        return modified_adj.tocsr()
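
# Usage sketch (illustrative, not part of the module): how a NAS search
# algorithm might query this estimator. `space` (a parameterized BaseSpace
# architecture) and `dataset` (a node-classification dataset in AutoGL's
# expected format) are hypothetical placeholders assumed to be built elsewhere.
#
#   estimator = GRNAEstimator(
#       loss_f="nll_loss",
#       evaluation=[Acc()],
#       lambda_=0.05,           # weight of the robustness term
#       perturb_type="random",  # or "dice", which needs node labels
#       adv_sample_num=10,      # perturbed graphs averaged per evaluation
#       dis_type="ce",          # prediction-shift measure, see distance() below
#       ptbr=0.05,              # perturb 5% of the edges
#   )
#   metrics, loss = estimator.infer(space, dataset, mask="train")
#   # metrics[i] = evaluation[i] on clean predictions
#   #              + lambda_ * average prediction shift over the perturbed graphs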

def distance(perturb, clean, dis_type='ce', data=None, p=2):
    """
    Distance between the logits of perturbed and clean data.

    Parameters
    ----------
    perturb : torch.Tensor [n, C]
        Logits on the perturbed graph.
    clean : torch.Tensor [n, C]
        Logits on the clean graph.
    dis_type : str
        Distance type: 'fro', 'ce', 'kl' or 'cw'.
    data :
        Dataset providing train_mask and ground-truth labels; needed when
        dis_type='cw'.
    p : int
        Norm order, needed when dis_type='fro'.

    Returns
    -------
    Distance : scalar torch.Tensor

    A small sketch of the metrics follows this function.
    """
    # if dis_type == 'cos':
    #     return perturb * clean / torch.sqrt(torch.norm(perturb, p=2) * torch.norm(clean))
    if dis_type == 'fro':
        return torch.norm(perturb - clean, p=p)
    elif dis_type == 'ce':
        # cross-entropy between the clean (target) and perturbed distributions
        p_ = F.softmax(clean, -1)
        logq = F.log_softmax(perturb, -1)
        return -(p_ * logq).mean()
    elif dis_type == 'kl':
        # KL(clean || perturbed), scaled for comparability
        logq = F.log_softmax(perturb, -1)
        p_ = F.softmax(clean, -1)
        logp = F.log_softmax(clean, -1)
        return (p_ * (logp - logq)).mean() * 100
    elif dis_type == 'cw':
        # Carlini-Wagner style margin loss on the training nodes
        perturb, clean, labels = (
            perturb[data.train_mask],
            clean[data.train_mask],
            data.y[data.train_mask],
        )
        # one-hot encoding of the ground-truth labels
        eye = torch.eye(int(labels.max()) + 1, device=labels.device)
        one_hot_labels = eye[labels]
        # perturbed margin
        ptb_best_second_class = (perturb - 1000 * one_hot_labels).argmax(1)
        margin = (
            perturb[np.arange(len(perturb)), labels]
            - perturb[np.arange(len(perturb)), ptb_best_second_class]
        )
        ptb_loss = -torch.clamp(margin, max=0, min=-0.1).mean()
        # clean margin (computed from the clean logits)
        clean_best_second_class = (clean - 1000 * one_hot_labels).argmax(1)
        margin = (
            clean[np.arange(len(perturb)), labels]
            - clean[np.arange(len(perturb)), clean_best_second_class]
        )
        clean_loss = -torch.clamp(margin, max=0, min=-0.1).mean()
        return (ptb_loss - clean_loss) * 100
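
# Minimal sketch of the distance metrics above on random logits ('cw' would
# additionally need a `data` object carrying train_mask and labels). Guarded so
# it only runs when this file is executed directly, and assumes the module-level
# imports resolve in that context.
if __name__ == "__main__":
    clean_logits = torch.randn(8, 3)
    perturb_logits = clean_logits + 0.1 * torch.randn(8, 3)
    for t in ('fro', 'ce', 'kl'):
        # larger values indicate a larger shift between clean and perturbed predictions
        print(t, float(distance(perturb_logits, clean_logits, dis_type=t)))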