from __future__ import print_function

import os
import os.path as osp

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import time
import sys
import numpy as np
import torch
import random
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import is_undirected, to_undirected
import csv
from sklearn.model_selection import StratifiedKFold
import time

from tqdm import tqdm
from copy import deepcopy
import argparse
from utils import preprocess_adj, cal_acc
from models import graph_classifier, cam_attack
from models.Attacker import CAMAttack

# Training settings
def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"`` and ``"0"``), so an explicit
    converter is required for boolean options that take a value.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``
            (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}.')


parser = argparse.ArgumentParser()
# general set
parser.add_argument('--dataset', type=str, default="MUTAG",
                    help='Dataset to use.')  # [MUTAG, PROTEINS, ENZYMES, NCI1, COX2, IMDB-BINARY, IMDB-MULTI]
parser.add_argument('--use-node-attr', type=str2bool, default=True,
                    help='whether to use node attributes (true/false)')
parser.add_argument('--method', type=str, default='GCN',
                    help='graph classification model')  # [GCN, GIN, IGNN, GNet]
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--seed', type=int, default=100, help='Random seed.')
parser.add_argument('--epochs', type=int, default=100,
                    help='Number of epochs to train.')
parser.add_argument('--batch-size', type=int, default=128,
                    help='batch size')

# hyper parameter for GNNs
parser.add_argument('--lr', type=float, default=0.01,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=32,
                    help='Number of hidden units.')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='Dropout rate (1 - keep probability).')
# hyper parameter for GraphUNet
parser.add_argument('-l_num', type=int, default=4, help='layer num')
parser.add_argument('-h_dim', type=int, default=512, help='hidden dim')
parser.add_argument('-l_dim', type=int, default=48, help='layer dim')
parser.add_argument('-drop_n', type=float, default=0.3, help='drop net')
parser.add_argument('-act_n', type=str, default='ELU', help='network act')
# ``type=list`` would split a single argument into characters; take one or
# more space-separated floats instead.
parser.add_argument('-ks', type=float, nargs='+', default=[0.9, 0.7, 0.6, 0.5],
                    help='GraphUNet ks values (space-separated floats)')
# hyper parameter for IGNN
parser.add_argument('--kappa', type=float, default=0.98,
                    help='Projection parameter. ||W|| <= kappa/lpf(A)')
# Attack set
parser.add_argument('--attack-prop', type=float, default=0.1,
                    help='attack edge proportion')
parser.add_argument("--attacker", type=str, default='CAMA',
                    help="Attack Method")

def init_model():
    """Build a fresh graph classifier and its Adam optimizer.

    Reads the module-level ``args`` (``method``, ``lr``, ``weight_decay``),
    ``num_features``, ``dataset`` and ``device``.

    Returns:
        ``(model, optimizer)``: the classifier moved to ``device`` and an Adam
        optimizer over its parameters.

    Raises:
        ValueError: if ``args.method`` has no constructor in this function.
    """
    if args.method == 'GCN':
        # NOTE(review): hidden size is fixed at 64 here; ``args.hidden`` is not used.
        model = graph_classifier.GCN(num_features, 64, dataset.num_classes).to(device)
    else:
        # Without this branch ``model`` is unbound and the return fails with
        # an unhelpful UnboundLocalError.
        raise ValueError(
            f"Unsupported --method {args.method!r}; init_model() only builds 'GCN'.")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    return model, optimizer


def train(epoch):
    """Run one training epoch over ``train_loader`` and return the mean loss.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``,
    ``train_dataset``, ``device`` and ``normalize``.

    Args:
        epoch: 1-based epoch index. At epoch 51 the learning rate of every
            parameter group is halved (once).

    Returns:
        Cross-entropy loss averaged over all graphs in ``train_dataset``.
    """
    # ``test()`` puts the model in eval mode, so switch back every epoch;
    # otherwise any dropout / batch-norm layers stay in inference mode.
    model.train()

    if epoch == 51:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.5 * param_group['lr']

    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        adj_ori = preprocess_adj(data.edge_index, data.num_nodes, device, norm=False)
        if data.x is None:
            # Featureless graphs: use column sums of the raw (sparse)
            # adjacency as a single degree-like input feature.
            data.x = torch.sparse.sum(adj_ori, [0]).to_dense().unsqueeze(1).to(device)

        if normalize:
            _An = preprocess_adj(data.edge_index, data.num_nodes, device, norm=True)
            _, output = model(data.x, _An, data.batch)
        else:
            _, output = model(data.x, adj_ori, data.batch)
        loss = F.cross_entropy(output, data.y)
        loss.backward()
        # Weight by graphs in the batch so the result is a per-graph mean.
        loss_all += loss.item() * data.num_graphs
        optimizer.step()

    return loss_all / len(train_dataset)


def test(loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Puts the module-level ``model`` into eval mode (it stays there after
    return) and uses the module-level ``device`` and ``normalize``.

    Args:
        loader: A DataLoader of graphs with labels in ``data.y``.

    Returns:
        Summed ``cal_acc`` over all batches divided by ``len(loader.dataset)``.
    """
    model.eval()
    correct = 0
    # Inference only: skip building autograd graphs to save memory and time.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            adj_ori = preprocess_adj(data.edge_index, data.num_nodes, device, norm=False)
            if data.x is None:
                # Featureless graphs: degree-like feature from the raw adjacency.
                data.x = torch.sparse.sum(adj_ori, [0]).to_dense().unsqueeze(1).to(device)
            if normalize:
                _An = preprocess_adj(data.edge_index, data.num_nodes, device, norm=True)
                _, output = model(data.x, _An, data.batch)
            else:
                _, output = model(data.x, adj_ori, data.batch)

            correct += cal_acc(output, data.y)

    return correct / len(loader.dataset)


def evaluate(edge_list):
    """Score the model on the current graph ``g`` using ``edge_list`` as its edges.

    Relies on the module-level ``g``, ``num_nodes``, ``model``, ``device`` and
    ``normalize``. Every node is assigned to batch 0, i.e. the graph is scored
    on its own.

    Returns:
        ``cal_acc`` of the model output against ``g.y``.
    """
    model.eval()
    with torch.no_grad():
        adjacency = preprocess_adj(edge_list, num_nodes, device, normalize)
        single_graph_batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
        _embedding, logits = model(g.x, adjacency, single_graph_batch)
        score = cal_acc(logits, g.y)
        return score


torch.cuda.empty_cache()
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda:
    torch.cuda.manual_seed(args.seed)
device = torch.device('cuda' if args.cuda else 'cpu')

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Load data and split 10-fold
path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', args.dataset)
dataset = TUDataset(path, name=args.dataset, use_node_attr=args.use_node_attr).shuffle()
# Stratify on the labels in the *shuffled* order. ``dataset.data`` is the
# underlying un-permuted storage, so after ``shuffle()`` ``dataset.data.y[i]``
# is generally not the label of ``dataset[i]`` and the folds would not be
# stratified. Assumes one label per graph (graph classification).
labels = np.array([graph.y.item() for graph in dataset])
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=args.seed)
idx_list = list(skf.split(np.zeros(len(labels)), labels))

best_results = [0] * 10

num_features = dataset.num_features
if num_features == 0:
    # Featureless datasets get a single degree-like input feature in train()/test().
    num_features = 1

num_classes = dataset.num_classes
num_node_attributes = dataset.num_node_attributes
# GIN is fed the raw adjacency; every other method gets the normalized one.
normalize = args.method != 'GIN'

time_cost = []
after_acc_list = []

# Early-stopping settings: read from ``args`` if such options exist, otherwise
# enabled with the original patience of 500 epochs (which never triggers with
# the default --epochs 100).
early_stop = getattr(args, 'early_stop', True)
max_patience = getattr(args, 'patience', 500)
os.makedirs('ckpt', exist_ok=True)

# Fail fast instead of silently skipping the attack for an unknown method.
if args.attacker not in ('CAMA', 'CAMA_grad'):
    raise ValueError(f"Unknown --attacker {args.attacker!r}; expected 'CAMA' or 'CAMA_grad'.")

for fold_idx in range(10):

    train_idx, test_idx = idx_list[fold_idx]
    test_dataset = dataset[test_idx.tolist()]
    train_dataset = dataset[train_idx.tolist()]
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)
    # One path for both save and load, so the attack uses this fold's best model.
    ckpt_path = f'ckpt/dataset_{args.dataset}_{args.method}_model_fold_{fold_idx}.pth'

    print(f'==========================Start training for fold {fold_idx}==========================')
    model, optimizer = init_model()

    # Train Graph Classifier, checkpointing whenever test accuracy improves.
    # Start below any reachable accuracy so epoch 1 always writes a checkpoint
    # (otherwise an all-wrong run would leave nothing, or a stale file, to load).
    cur_acc = -1.0
    patience = 0

    for epoch in range(1, args.epochs + 1):
        train_loss = train(epoch)
        train_acc = test(train_loader)
        test_acc = test(test_loader)
        if test_acc > cur_acc:
            cur_acc = test_acc
            torch.save(model.state_dict(), ckpt_path)
            patience = 0
        elif early_stop:
            patience += 1
            if patience > max_patience:
                break

        print('Epoch: {:03d}, Train Loss: {:.7f}, '
                'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(epoch, train_loss,
                                                            train_acc, test_acc))

    # ================================================start evasion attack====================================
    print('================Start structural attack for fold ', fold_idx, '====================')

    if args.method == 'GCN':
        model = graph_classifier.GCN(num_features, 64, dataset.num_classes).to(device)

    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    # A freshly constructed model starts in training mode; score the victim in
    # inference mode so its predictions are deterministic.
    model.eval()
    last_W = list(model.parameters())[-2].data
    attacked = 0

    time0 = time.time()
    for g in tqdm(test_dataset):
        g = g.to(device)
        num_nodes = g.num_nodes
        edge_index = g.edge_index.clone()
        if not is_undirected(edge_index):
            print('Change directed graph to undirected')
            edge_index = to_undirected(edge_index)
        # Budget: a fraction of the undirected edge count (each edge is stored twice).
        n_perturbations = int(np.ceil(len(edge_index[0]) / 2 * args.attack_prop))
        print(
            f'There are {len(edge_index[0]) / 2} edges, {num_nodes} nodes in this graph. Graph label is {g.y.item()}.')
        print('n_perturbations:', n_perturbations)

        with torch.no_grad():
            _An = preprocess_adj(edge_index, num_nodes, device, normalize)
            _, output = model(g.x, _An, torch.LongTensor([0] * num_nodes).to(device))

        orig_acc = cal_acc(output, g.y)

        # Graphs already misclassified count as successfully attacked.
        if orig_acc == 0:
            attacked += 1
            continue

        if args.attacker == 'CAMA':
            cama = CAMAttack(model, args.method, device)
        else:  # 'CAMA_grad' (validated before the fold loop)
            cama = CAMAttack(model, args.method, device, grad=True)
        after_acc = cama.struct_attack(g, n_perturbations)
        if after_acc == 0:
            attacked += 1

    time_cost.append(time.time() - time0)
    after_acc_list.append(1 - attacked / len(test_dataset))
    print(f'Fold {fold_idx}: accuracy after attack {after_acc_list[-1]:.4f}, '
          f'attack time {time_cost[-1]:.1f}s')

print(f'Accuracy after attack over {len(after_acc_list)} folds: '
      f'{np.mean(after_acc_list):.4f} +/- {np.std(after_acc_list):.4f}')
