from __future__ import print_function

import os
import os.path as osp

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import time
import sys
import numpy as np
import torch
import random
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.autograd.gradcheck import zero_gradients
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import is_undirected, to_undirected, degree
import csv
from sklearn.model_selection import StratifiedKFold
import time

from tqdm import tqdm
from copy import deepcopy
import argparse
import pandas as pd
import scipy.sparse as sp

from models import graph_classifier
from utils import preprocess_adj, cal_acc
from models.Attacker import CAMAttack

def _str2bool(value):
    """Convert a command-line string such as ``'True'``, ``'false'`` or ``'1'`` to a bool.

    ``type=bool`` cannot be used with argparse: ``bool('False')`` is ``True``,
    so any non-empty value would enable the option.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Training settings
parser = argparse.ArgumentParser()
# general set
parser.add_argument('--dataset', type=str, default="MUTAG",
                    help='Dataset to use.')
parser.add_argument('--use-node-attr', type=_str2bool, default=True,
                    help='whether to use node attributes')
parser.add_argument('--method', type=str, default='GCN',
                    help='graph classification model')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--seed', type=int, default=100, help='Random seed.')

# hyper parameter for GNNs
parser.add_argument('--epochs', type=int, default=100,
                    help='Number of epochs to train.')
parser.add_argument('--batch-size', type=int, default=128,
                    help='batch size')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=64,
                    help='Number of hidden units.')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='Dropout rate (1 - keep probability).')
# Read by the IGNN model constructor; previously missing, which made
# ``args.kappa`` raise AttributeError.
# NOTE(review): default 0.9 is an assumption; confirm against the IGNN setup used.
parser.add_argument('--kappa', type=float, default=0.9,
                    help='kappa hyper-parameter of the IGNN model')

# Attack set
parser.add_argument('--attack-prop', type=float, default=0.1,
                    help='attack edge proportion')
# A proportion, so it must be parsed as float (type=int rejected e.g. "0.2").
parser.add_argument('--feature-prop', type=float, default=0.1,
                    help='Proportion of node features chosen to be attacked')
parser.add_argument("--norm-length", type=float, default=0.1,
                    help="Variable lambda (noise number)")
parser.add_argument("--attacker", type=str, default='CAMA',
                    help="Attack Method")

def init_model():
    """Build a fresh graph classifier for ``args.method`` and an Adam optimizer for it.

    Reads the module-level ``args``, ``dataset`` and ``device``.

    Returns:
        ``(model, optimizer)``: the model already moved to ``device`` and an
        Adam optimizer over its parameters using ``args.lr`` / ``args.weight_decay``.

    Raises:
        ValueError: if ``args.method`` is not one of GCN, GIN, GNet or IGNN.
    """
    num_features = dataset.num_features
    num_classes = dataset.num_classes

    if args.method == 'GCN':
        model = graph_classifier.GCN(num_features, args.hidden, num_classes)
    elif args.method == 'GIN':
        model = graph_classifier.GIN(num_features, args.hidden, num_classes)
    elif args.method == 'GNet':
        model = graph_classifier.GNet(num_features, num_classes, args)
    elif args.method == 'IGNN':
        # NOTE(review): IGNN uses a fixed hidden size of 32 rather than
        # args.hidden; kept as-is so existing IGNN checkpoints stay loadable.
        model = graph_classifier.IGNN(num_features, 32, num_classes, num_node=None,
                                      dropout=args.dropout, kappa=args.kappa)
    else:
        raise ValueError(f'Unsupported graph classification method: {args.method!r}')

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    return model, optimizer


def train(epoch):
    """Run one optimisation epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``,
    ``train_dataset``, ``device`` and ``normalize``. When ``epoch == 51`` the
    learning rate of every parameter group is halved before training.

    Args:
        epoch: current epoch number; only used for the one-off LR decay.

    Returns:
        Cross-entropy loss averaged over all graphs in ``train_dataset``.
    """
    if epoch == 51:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.5 * param_group['lr']

    # test() calls model.eval(), and it runs after every epoch. Switch back to
    # training mode here; otherwise any dropout / batch-norm layers would stay
    # in inference mode for every epoch after the first.
    model.train()

    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        _An = preprocess_adj(data.edge_index, data.x.size(0), device, norm=normalize)
        _, output = model(data.x, _An, data.batch)
        loss = F.cross_entropy(output, data.y)
        loss.backward()
        # cross_entropy averages over the batch; re-weight by the batch's graph
        # count so the returned value is a per-graph mean over the whole set.
        loss_all += loss.item() * data.num_graphs
        optimizer.step()

    return loss_all / len(train_dataset)


def test(loader):
    """Return the classification accuracy of the module-level ``model`` on ``loader``.

    The accuracy is the sum of ``cal_acc`` over all batches divided by
    ``len(loader.dataset)``. (Presumably ``cal_acc`` returns the number of
    correct predictions in a batch; confirm in utils.)

    Evaluation runs under ``torch.no_grad()`` because no gradients are needed,
    which avoids building an autograd graph for every batch.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            _An = preprocess_adj(data.edge_index, data.x.size(0), device, normalize)
            _, output = model(data.x, _An, data.batch)
            correct += cal_acc(output, data.y)

    return correct / len(loader.dataset)


def evaluate(feat):
    """Score the model's prediction for the graph currently being attacked.

    Runs ``model`` in eval mode without gradients on node features ``feat``.
    The module-level ``edge_index`` and ``num_nodes`` define the graph, and
    the label is ``g.y``. All nodes are assigned to batch 0, so the model sees
    a single graph. Returns whatever ``cal_acc`` reports for that prediction.
    """
    model.eval()
    with torch.no_grad():
        adj = preprocess_adj(edge_index, num_nodes, device, normalize)
        # Single-graph batch vector: every node belongs to graph 0.
        batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
        _node_repr, logits = model(feat, adj, batch)
        return cal_acc(logits, g.y)


def pick_feature(grad, k):
    """Choose the ``k`` feature columns whose summed gradient has the largest magnitude.

    Args:
        grad: gradient of the loss w.r.t. the node-feature matrix. It is summed
            over dim 0, leaving one score per feature column (assumes a
            (num_nodes, num_features) layout; confirm with callers).
        k: number of columns to select. It is clamped to the number of columns,
            so ``torch.topk`` cannot fail when ``k`` is too large.

    Returns:
        signs: float tensor on the CPU with the default dtype, one entry per
            feature column. Selected columns hold the sign of their summed
            gradient; all other columns hold 0.
        indexs: indices of the selected columns, on ``grad``'s device, in
            descending order of score magnitude.
    """
    score = grad.sum(dim=0)
    num_cols = score.numel()
    k = min(int(k), num_cols)
    _, indexs = torch.topk(score.abs(), k)
    # The length comes from grad itself, so no global graph is needed.
    # Vectorised assignment replaces the per-index Python loop.
    signs = torch.zeros(num_cols)
    signs[indexs.cpu()] = score[indexs].sign().to(device='cpu', dtype=signs.dtype)
    return signs, indexs

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda:
    torch.cuda.manual_seed(args.seed)
device = torch.device('cuda:0' if args.cuda else 'cpu')

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Load data and split it into 10 stratified folds.
path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', args.dataset)
dataset = TUDataset(path, name=args.dataset, use_node_attr=args.use_node_attr).shuffle()
# Read the labels through the shuffled view. ``dataset.data.y`` is the
# underlying collated storage in the original (unshuffled) order, so
# stratifying on it would not match the positions later passed to dataset[...].
labels = np.array([int(graph.y) for graph in dataset])
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=args.seed)
idx_list = list(skf.split(np.zeros(len(labels)), labels))

num_classes = dataset.num_classes
num_node_attributes = dataset.num_node_attributes
# preprocess_adj is called with normalisation off for GIN and on for every other model.
normalize = args.method != 'GIN'
# Patience-based early stopping is only enabled for ENZYMES.
early_stop = args.dataset == 'ENZYMES'

# One robust-accuracy value per fold, appended by the fold loop.
results = []

# torch.save() does not create missing directories, so make sure the
# checkpoint directory exists before the first fold is trained.
os.makedirs('ckpt', exist_ok=True)

for fold_idx in range(10):

    train_idx, test_idx = idx_list[fold_idx]
    test_dataset = dataset[test_idx.tolist()]
    train_dataset = dataset[train_idx.tolist()]
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)
    ckpt_path = f'ckpt/dataset_{args.dataset}_{args.method}_model_fold_{fold_idx}.pth'

    model, optimizer = init_model()

    # Train the graph classifier and keep the checkpoint with the best test
    # accuracy. cur_acc starts below any reachable accuracy, so the first epoch
    # always writes a checkpoint (the reload below needs one). patience is
    # initialised up front so the early-stopping branch never reads an unbound name.
    cur_acc = -1
    patience = 0

    for epoch in range(1, args.epochs + 1):
        train_loss = train(epoch)
        train_acc = test(train_loader)
        test_acc = test(test_loader)
        if test_acc > cur_acc:
            cur_acc = test_acc
            torch.save(model.state_dict(), ckpt_path)
            patience = 0
        elif early_stop:
            patience += 1
            # NOTE(review): with the default 100 epochs this limit is never reached.
            if patience > 500:
                break

        print('Epoch: {:03d}, Train Loss: {:.7f}, '
              'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(epoch, train_loss,
                                                          train_acc, test_acc))
    print(f'==========================Start feature attack for fold {fold_idx}...')
    # Reload the best checkpoint into the model that init_model() built for this
    # fold. The architecture (e.g. hidden size) then always matches what was
    # saved, unlike rebuilding it with hard-coded sizes.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    last_W = list(model.parameters())[-2].data

    attack_results = 0

    for pos in tqdm(range(len(test_dataset))):
        g = test_dataset[pos].to(device)
        num_nodes = g.num_nodes
        edge_index = g.edge_index

        print(f'There are {len(edge_index[0])} edges, {num_nodes} nodes in this graph. Graph label is {g.y.item()}.')
        if not is_undirected(edge_index):
            print('Change directed graph to undirected')
            # NOTE(review): only the local edge_index is symmetrised; it is used
            # by evaluate() and the gradient pass below. g.edge_index, which is
            # passed to CAMAttack, keeps the original directed edges. Confirm
            # this is intended.
            edge_index = to_undirected(edge_index)
        orig_acc = evaluate(g.x)

        # A graph that is already misclassified counts as a successful attack.
        if orig_acc == 0:
            attack_results += 1
            continue

        # Attack budget handed to CAMAttack: ceil(attack_prop * num_nodes).
        r = int(np.ceil(num_nodes * args.attack_prop))

        # Gradient of the classification loss w.r.t. the node features. Its
        # column sums decide which features are attacked and in which direction.
        g.x.requires_grad_(True)
        model.eval()
        _An = preprocess_adj(edge_index, num_nodes, device, normalize)
        h_v, logits = model(g.x, _An, torch.LongTensor([0] * num_nodes).to(device))
        print('logits', logits)
        loss = F.cross_entropy(logits, g.y)
        # Drop any gradient left on g.x by an earlier backward pass. This is
        # used instead of torch.autograd.gradcheck.zero_gradients, which recent
        # PyTorch releases no longer provide.
        g.x.grad = None
        loss.backward()
        grad = g.x.grad
        k = int(np.ceil(args.feature_prop * dataset.num_features))
        signs, indexs = pick_feature(grad, k)
        g.x.requires_grad_(False)
        print(f'Choose {k} node feature: ', indexs)

        if args.attacker == 'CAMA':
            cama = CAMAttack(model, args.method, device)
            after_acc = cama.feature_attack(g, args.norm_length, indexs, signs, r)
            if after_acc == 0:
                attack_results += 1
        elif args.attacker == 'CAMA_grad':
            cama = CAMAttack(model, args.method, device, grad=True)
            after_acc = cama.feature_attack(g, args.norm_length, indexs, signs, r)
            if after_acc == 0:
                attack_results += 1

    # Fraction of this fold's test graphs not counted as successfully attacked.
    results.append(1 - attack_results / len(test_dataset))
