from itertools import combinations
from numpy.core.numeric import False_
import seaborn as sns
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.autograd import Variable

import random
from torch_geometric.utils import degree, num_nodes
import torch.nn.functional as F

from utils import preprocess_adj, cal_acc


class CAMAttack():
    def __init__(self, surrogate, clf_name, device='cpu', grad=False):
        self.surrogate = surrogate
        self.clf_name = clf_name
        self.grad = grad
        self.device = device
        self.last_W = list(surrogate.parameters())[-2].data

        # normalize adj
        if self.clf_name == 'GIN':
            self.normalize = False
        else:
            self.normalize = True

    def get_L_cam(self, graph):
        """Compute per-node, per-class activation scores with the surrogate.

        Side effects (attributes set on ``self``):
            adj_orig: result of ``preprocess_adj`` for ``graph.edge_index``.
            x: first intermediate output returned by the surrogate when called
                with ``return_mid=True``.
            L_cam: the score matrix, one row per node and one column per class.

        In CAM mode (``self.grad`` false) the scores are
        ``relu(h_v @ last_W.T)`` where ``h_v`` is the first output of a plain
        surrogate forward pass. In Grad-CAM mode each intermediate output is
        weighted, per class, by the node-averaged gradient of that class logit
        with respect to it; the weighted maps are passed through ReLU and
        averaged over layers.

        Args:
            graph: object providing ``x`` (node features) and ``edge_index``.

        Raises:
            ValueError: if the surrogate returns no intermediate outputs.
        """
        num_nodes = graph.x.size(0)
        num_classes = self.last_W.size(0)
        # All nodes are assigned to graph 0 of the batch.
        batch = torch.zeros(num_nodes, dtype=torch.long, device=self.device)
        self.adj_orig = preprocess_adj(graph.edge_index, total_nodes_num=num_nodes, norm=self.normalize,
                                       device=self.device)  # sparse matrix

        # The plain forward pass is only needed by the CAM branch, so it is
        # skipped entirely in Grad-CAM mode.
        h_v = None
        if not self.grad:
            h_v, _ = self.surrogate(graph.x, self.adj_orig, batch)

        # The surrogate returns its intermediate outputs followed by the
        # logits; taking "everything but the last" works for any depth and
        # keeps the layer count consistent with what is used below.
        *mids, logits = self.surrogate(graph.x, self.adj_orig, batch, return_mid=True)
        if not mids:
            raise ValueError('surrogate returned no intermediate outputs with return_mid=True')
        self.x = mids[0]

        if self.grad:
            # grad CAM
            x = torch.stack(mids, dim=0)  # [L, N, D]
            L_cam = x.new_empty(num_nodes, num_classes)
            for c in range(num_classes):
                # Summing is a no-op for the single-graph case (one logit per
                # class) and keeps the output scalar otherwise.
                y_c = logits[:, c].sum()
                # autograd.grad returns the gradients directly: no hooks pile
                # up across classes and parameter .grad fields stay untouched.
                layer_grads = torch.autograd.grad(y_c, mids, retain_graph=True)
                alphas = torch.stack(layer_grads, dim=0).mean(dim=1)  # [L, D]
                L_cam_c = torch.matmul(x, alphas.unsqueeze(-1)).squeeze(-1)  # [L, N]
                L_cam[:, c] = torch.mean(F.relu(L_cam_c), dim=0)  # [N, ]
        else:
            # CAM
            L_cam = F.relu(torch.matmul(h_v, self.last_W.T))

        self.L_cam = L_cam

    def struct_attack(self, graph, n_perturbations, victim_model=None, victim_normalize=None):
        """Run the structure attack per class ranking, then with a merged ranking.

        Nodes are ranked by descending activation score for every class column
        of ``self.L_cam``. ``struc_attack_col`` is called with each class
        ranking in turn; as soon as one call returns 0 the method returns 0.
        Otherwise the class rankings are merged (row-major over the rank
        matrix, keeping the first occurrence of each node) and the result of
        one final ``struc_attack_col`` call with that merged ranking is
        returned.

        Args:
            graph: graph object forwarded to ``get_L_cam`` / ``struc_attack_col``.
            n_perturbations: edge-flip budget forwarded to ``struc_attack_col``.
            victim_model: forwarded to ``struc_attack_col``.
            victim_normalize: forwarded to ``struc_attack_col``.

        Returns:
            Whatever ``struc_attack_col`` returns for the deciding call, or 0.
        """
        self.get_L_cam(graph)
        cam_scores = self.L_cam.detach().cpu().numpy()

        # Row-normalize the cached embeddings and take all pairwise dot
        # products (cosine similarity between nodes).
        row_norms = torch.norm(self.x, p=2, dim=1, keepdim=True)
        unit_rows = self.x / row_norms
        similarity = torch.mm(unit_rows, unit_rows.T)

        rank_idx_matrix = np.zeros(cam_scores.shape)
        for class_idx, class_scores in enumerate(cam_scores.T):
            # Stable sort, then reversed: highest score first.
            ranking = class_scores.argsort(kind='mergesort')[::-1]
            rank_idx_matrix[:, class_idx] = ranking
            class_acc = self.struc_attack_col(n_perturbations, ranking, graph, similarity, victim_model,
                                              victim_normalize)
            if class_acc == 0:
                return 0

        # U_global: walk the rank matrix row by row and keep each node the
        # first time it shows up.
        flat_ranks = rank_idx_matrix.reshape(-1)
        _, first_positions = np.unique(flat_ranks, return_index=True)
        U_global = [int(flat_ranks[pos]) for pos in np.sort(first_positions)]

        return self.struc_attack_col(n_perturbations, U_global, graph, similarity, victim_model, victim_normalize)

    def feature_attack(self, graph, norm_length, indexs, signs, perturb_nodes_num, victim_model=None,
                       victim_normalize=None):
        """Perturb features of the top-ranked nodes and score the victim model.

        Nodes are ranked by descending activation score for every class column
        of ``self.L_cam``. For each class ranking, and finally for the merged
        ranking (first occurrence of each node, row-major over the rank
        matrix), the first ``perturb_nodes_num`` nodes get
        ``norm_length * signs[index]`` added to every feature ``index`` in
        ``indexs``. Every evaluation starts from a fresh copy of ``graph.x``,
        so perturbations from different rankings never stack.

        Side effects:
            ``self.modified_feature`` holds the feature matrix of the most
            recent evaluation; ``victim_model`` is switched to eval mode.

        Args:
            graph: object providing ``x``, ``edge_index`` and ``y``.
            norm_length: magnitude added to each selected feature.
            indexs: iterable of feature indices to perturb.
            signs: indexable by feature index; multiplies ``norm_length``.
            perturb_nodes_num: number of top-ranked nodes to perturb.
            victim_model: model to evaluate; defaults to the surrogate, in
                which case the adjacency cached by ``get_L_cam`` is reused.
            victim_normalize: ``norm`` argument passed to ``preprocess_adj``
                when an explicit ``victim_model`` is given.

        Returns:
            0 as soon as any per-class evaluation yields 0, otherwise the
            ``cal_acc`` result for the merged ranking.
        """
        self.get_L_cam(graph)
        L_cam = self.L_cam.detach().cpu().numpy()
        # Node indices are integers; keep them exact regardless of L_cam dtype.
        rank_idx_matrix = np.zeros(L_cam.shape, dtype=np.int64)
        num_nodes = graph.x.size(0)

        if victim_model is None:
            victim_model = self.surrogate
            adj = self.adj_orig
        else:
            # Keyword arguments, matching the call in get_L_cam.
            adj = preprocess_adj(graph.edge_index, total_nodes_num=num_nodes, norm=victim_normalize,
                                 device=self.device)

        batch = torch.zeros(num_nodes, dtype=torch.long, device=self.device)

        def perturbed_accuracy(nodes):
            # Always start from the clean features so that earlier
            # perturbations cannot leak into this evaluation.
            feat = graph.x.clone()
            for u in nodes:
                for index in indexs:
                    feat[u][index] += norm_length * signs[index]
            self.modified_feature = feat

            victim_model.eval()
            with torch.no_grad():
                _, output = victim_model(feat, adj, batch)
                return cal_acc(output, graph.y)

        for j in range(L_cam.shape[1]):
            # Stable sort, then reversed: highest score first.
            U_c = L_cam[:, j].argsort(kind='mergesort')[::-1]
            rank_idx_matrix[:, j] = U_c
            if perturbed_accuracy(U_c[:perturb_nodes_num]) == 0:
                return 0

        # U_global: first occurrence of every node when reading the rank
        # matrix row by row.
        a = rank_idx_matrix.reshape(-1)
        first_positions = np.unique(a, return_index=True)[1]
        rk_nodes = [int(a[pos]) for pos in sorted(first_positions)]

        return perturbed_accuracy(rk_nodes[:perturb_nodes_num])

    def struc_attack_col(self, n_perturbations, U_c, graph, similarity, victim_model=None, victim_normalize=None):
        """Flip edges between top-ranked nodes and score the victim on the result.

        Nodes are admitted in the order given by ``U_c``. The first candidate
        pair is formed by the two highest-ranked nodes; afterwards every newly
        admitted node is paired with all nodes admitted before it. For each
        candidate pair ``(u, v)``:

        * if ``(u, v)`` is an existing edge and ``similarity[u, v] >= 0.95``,
          both directions of the edge are removed;
        * if ``(u, v)`` is not an existing edge and ``similarity[u, v] <= 1``,
          both directions are appended.

        The search stops when ``n_perturbations`` pairs have been flipped or
        the ranking is exhausted. A non-positive budget leaves the graph
        untouched.

        Side effects:
            ``self.modified_adj`` is set to the resulting edge index;
            ``victim_model`` is switched to eval mode.

        Args:
            n_perturbations: maximum number of node pairs to flip.
            U_c: node indices ordered from most to least important.
            graph: object providing ``x``, ``edge_index`` and ``y``.
            similarity: 2-D structure indexable as ``similarity[u, v]``.
            victim_model: model to evaluate; defaults to the surrogate (then
                ``self.normalize`` replaces ``victim_normalize``).
            victim_normalize: ``norm`` argument passed to ``preprocess_adj``.

        Returns:
            The ``cal_acc`` result of the victim on the modified graph.
        """
        edge_index = graph.edge_index.clone()
        # Real set of plain-int pairs: O(1) membership tests.
        existing_edges = {(src, dst) for src, dst in edge_index.T.cpu().tolist()}
        ranked_nodes = [int(node) for node in U_c]
        num_nodes = len(ranked_nodes)
        if victim_model is None:
            victim_model = self.surrogate
            victim_normalize = self.normalize

        structural_perturbations = []
        import_nodes = ranked_nodes[:2]
        potential_edge_list = list(combinations(import_nodes, 2))

        r = 1  # rank of the most recently admitted node
        budget_spent = n_perturbations <= 0
        while not budget_spent and r < num_nodes:
            for u, v in potential_edge_list:
                sim_uv = similarity[u, v]
                if (u, v) in existing_edges:
                    # delete (only sufficiently similar endpoints)
                    if not sim_uv >= 0.95:
                        continue
                    # Mask out both directions. Comparing on the tensor itself
                    # is required: tuples of tensor elements hash by identity,
                    # so a Python set difference would remove nothing.
                    src, dst = edge_index[0], edge_index[1]
                    keep = ~(((src == u) & (dst == v)) | ((src == v) & (dst == u)))
                    edge_index = edge_index[:, keep]
                    existing_edges.discard((u, v))
                    existing_edges.discard((v, u))
                elif sim_uv <= 1:
                    # add
                    new_edges = torch.tensor([[u, v], [v, u]], dtype=torch.long, device=edge_index.device)
                    edge_index = torch.cat((edge_index, new_edges), 1)
                    existing_edges.update(((u, v), (v, u)))
                else:
                    continue

                structural_perturbations.append((u, v))
                if len(structural_perturbations) >= n_perturbations:
                    budget_spent = True
                    break

            if budget_spent:
                break

            # update importance nodes according to rank; stop cleanly once the
            # ranking is exhausted instead of indexing past its end
            r += 1
            if r >= num_nodes:
                break
            new_node = ranked_nodes[r]
            potential_edge_list = [(new_node, old_node) for old_node in import_nodes]
            import_nodes.append(new_node)

        self.modified_adj = edge_index

        # Keyword arguments, matching the call in get_L_cam.
        adj = preprocess_adj(edge_index, total_nodes_num=num_nodes, norm=victim_normalize, device=self.device)
        batch = torch.zeros(num_nodes, dtype=torch.long, device=self.device)
        # Evaluate the same way feature_attack does: eval mode, no autograd.
        victim_model.eval()
        with torch.no_grad():
            _, output = victim_model(graph.x, adj, batch)
            cur_acc = cal_acc(output, graph.y)

        return cur_acc
