import os
import cv2
import random
import shutil
import subprocess
import numpy as np
import torch.optim
from tqdm import tqdm
import torch.nn.parallel
from ops.models import TSN
from ops.transforms import *
from functools import partial
from mtcnn.mtcnn import MTCNN
from keras.models import Model
from multiprocessing import Pool
from keras.models import load_model
import joblib  # sklearn.externals.joblib was removed in recent scikit-learn releases
from tensorflow.keras.preprocessing.image import img_to_array

import torchvision  # used below via torchvision.transforms.Compose
from ops.dataset import TSNDataSet
from torch.nn import functional as F

os.environ["CUDA_VISIBLE_DEVICES"] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

class FeatureExtractor(object):
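    """Wrap a pre-trained Keras emotion model and expose features from an intermediate layer.

    A second Model is built whose output is `out_put_layer`, so `inference()` returns a
    per-image embedding instead of the final classification scores.
    """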
    def __init__(self, input_size=224, out_put_layer='global_average_pooling2d_1', model_path='InceptionResNetV2-final.h5'):
        self.model = load_model(model_path)
        self.input_size = input_size
        self.model_inter = Model(inputs=self.model.input, outputs=self.model.get_layer(out_put_layer).output)

    def inference(self, image):
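        """Return the embedding vector for a single BGR face crop (as read by cv2)."""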
        image = cv2.resize(image, (self.input_size, self.input_size))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype("float") / 255.0
        image = img_to_array(image)
        image = np.expand_dims(image, axis=0)
        feature = self.model_inter.predict(image)[0]
        return feature


def extract_wav(config):
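    """Extract a 16 kHz WAV track from every video in VIDEO_DIR with ffmpeg and copy the
    source video into VIDEO_SAVE_DIR."""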
    video_dir = config['VIDEO']['VIDEO_DIR']
    video_save_dir = config['VIDEO']['VIDEO_SAVE_DIR']
    audio_save_dir = config['VIDEO']['AUDIO_SAVE_DIR']

    assert os.path.exists(video_dir)
    video_names = os.listdir(video_dir)
    for video_index, video_name in enumerate(video_names):
        file_name = video_name.split('.')[0]
        video_path = os.path.join(video_dir, video_name)

        assert os.path.exists(audio_save_dir)
        assert os.path.exists(video_save_dir)

        audio_name = file_name + '.wav'
        audio_save_path = os.path.join(audio_save_dir, audio_name)
        video_save_path = os.path.join(video_save_dir, video_name)

        command = 'ffmpeg -y -i {} -f wav -ar 16000 {}'.format(video_path, audio_save_path)
        # Run ffmpeg synchronously; os.popen returns before the WAV is fully written.
        subprocess.call(command, shell=True)
        shutil.copyfile(video_path, video_save_path)


def video2frame(file_name, class_path, dst_class_path):
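    """Decode a single .mp4 into JPEG frames (img_%05d.jpg, height 331 px) under
    dst_class_path, skipping videos whose frames already exist."""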
    if '.mp4' not in file_name:
        return
    name, ext = os.path.splitext(file_name)
    dst_directory_path = os.path.join(dst_class_path, name)

    video_file_path = os.path.join(class_path, file_name)
    try:
        if os.path.exists(dst_directory_path):
            if not os.path.exists(os.path.join(dst_directory_path, 'img_00001.jpg')):
                subprocess.call('rm -r \"{}\"'.format(dst_directory_path), shell=True)
                print('remove {}'.format(dst_directory_path))
                os.mkdir(dst_directory_path)
            else:
                print('*** convert has been done: {}'.format(dst_directory_path))
                return
        else:
            os.mkdir(dst_directory_path)
    except Exception:
        print(dst_directory_path)
        return
    cmd = 'ffmpeg -i \"{}\" -threads 1 -vf scale=-1:331 -q:v 0 \"{}/img_%05d.jpg\"'.format(video_file_path,
                                                                                           dst_directory_path)
    # print(cmd)
    subprocess.call(cmd, shell=True,
                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)


def extract_frame(config):
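    """Decode every video in VIDEO_SAVE_DIR into frame folders under FRAME_SAVE_DIR
    using a multiprocessing pool."""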
    video_save_dir = config['VIDEO']['VIDEO_SAVE_DIR']
    frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR']
    n_thread = config['VIDEO']['THREAD_NUM']

    assert os.path.exists(video_save_dir)
    video_names = os.listdir(video_save_dir)

    if not os.path.exists(frame_save_dir):
        os.mkdir(frame_save_dir)

    p = Pool(n_thread)
    worker = partial(video2frame, class_path=video_save_dir, dst_class_path=frame_save_dir)
    for _ in tqdm(p.imap_unordered(worker, video_names), total=len(video_names)):
        pass

    p.close()
    p.join()


def extract_frame_pose(config):
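    """Same as extract_frame, but the frame folders are written under POSE_FRAME_SAVE_DIR."""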
    video_save_dir = config['VIDEO']['VIDEO_SAVE_DIR']
    frame_save_dir = config['VIDEO']['POSE_FRAME_SAVE_DIR']
    n_thread = config['VIDEO']['THREAD_NUM']

    assert os.path.exists(video_save_dir)
    video_names = os.listdir(video_save_dir)

    if not os.path.exists(frame_save_dir):
        os.mkdir(frame_save_dir)

    p = Pool(n_thread)
    worker = partial(video2frame, class_path=video_save_dir, dst_class_path=frame_save_dir)
    for _ in tqdm(p.imap_unordered(worker, video_names), total=len(video_names)):
        pass

    p.close()
    p.join()


def extract_is10(config):
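    """Run openSMILE's SMILExtract with the IS10 paralinguistic config on every WAV in
    AUDIO_SAVE_DIR, writing one CSV of features per audio file."""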
    open_smile_dir = config['AUDIO']['OPENSMILE_DIR']
    audio_save_dir = config['VIDEO']['AUDIO_SAVE_DIR']
    is10_save_dir = config['VIDEO']['IS10_FEATURE_CSV_DIR']

    assert os.path.exists(audio_save_dir)
    audio_names = os.listdir(audio_save_dir)

    if not os.path.exists(is10_save_dir):
        os.mkdir(is10_save_dir)

    for audio_name in audio_names:
        audio_save_path = os.path.join(audio_save_dir, audio_name)
        csv_name = audio_name[:-4] + '.csv'
        csv_path = os.path.join(is10_save_dir, csv_name)

        smile_config_path = '{}/config/IS10_paraling.conf'.format(open_smile_dir)
        command = '{}/SMILExtract -C {} -I {} -O {}'.format(open_smile_dir, smile_config_path, audio_save_path, csv_path)
        # Run synchronously and avoid shadowing the `config` argument.
        subprocess.call(command, shell=True)


def extract_face_feature(config):
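    """Detect faces with MTCNN every INTERVAL frames, embed each crop with the emotion
    model, and save the stacked embeddings as one .npy file per video."""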
    feature_emotion_path = config['MODEL']['FEATURE_EMOTION']
    frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR']
    face_feature_dir = config['VIDEO']['FACE_FEATURE_DIR']
    interval = config['EMOTION']['INTERVAL']
    input_size = config['EMOTION']['INPUT_SIZE']
    prefix = config['VIDEO']['PREFIX']

    feature_extractor = FeatureExtractor(
        input_size=input_size, out_put_layer='global_average_pooling2d_1', model_path=feature_emotion_path)
    mtcnn_detector = MTCNN()

    video_names = os.listdir(frame_save_dir)
    for video_index, video_name in enumerate(video_names):
        print('{}/{}'.format(video_index, len(video_names)))
        video_dir = os.path.join(frame_save_dir, video_name)
        frame_names = os.listdir(video_dir)
        end = 0
        features = []

        while end < len(frame_names):

            if end % interval == 0:
                frame_name = prefix.format(end + 1)
                frame_path = os.path.join(video_dir, frame_name)

                frame = cv2.imread(frame_path)
                if frame is None:
                    end += 1
                    continue
                img_h, img_w, img_c = frame.shape
                detect_faces = mtcnn_detector.detect_faces(frame)
                for i, e in enumerate(detect_faces):
                    x1, y1, w, h = e['box']
                    # Clamp the detected box origin to the image bounds.
                    x1 = min(max(x1, 0), img_w)
                    y1 = min(max(y1, 0), img_h)

                    face = frame[y1:y1 + h, x1:x1 + w, :]
                    if face.size == 0:
                        continue
                    # inference() already returns the feature vector for a single image.
                    features.append(feature_extractor.inference(face))
                # top_5 = {}
                # for i, e in enumerate(detect_faces):
                #     x1, y1, w, h = e['box']
                #     x1 = x1 if x1 > 0 else 0
                #     y1 = y1 if y1 > 0 else 0
                #     x1 = x1 if x1 < img_w else img_w
                #     y1 = y1 if y1 < img_h else img_h
                #
                #     top_5[w*h] = [x1, y1, w, h]
                #
                # top_5 = sorted(top_5.items(), key=lambda d:d[0], reverse=True)
                # j = 0
                # for v in top_5:
                #     if j > 5:
                #         break
                #     x1, y1, w, h = v[1]
                #     face = frame[y1:y1+h, x1:x1+w, :]
                #     if face is []:
                #         continue
                #     features.append(feature_extractor.inference(face)[0])

            end += 1
        if len(features) == 0:
            continue
        features_np = np.array(features)
        face_feature_path = os.path.join(face_feature_dir, video_name + '.npy')
        np.save(face_feature_path, features_np)


def extract_random_face_feature(config):
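    """Embed every saved face crop per video, then save the full embedding matrix plus
    three random subsets (80%, 60%, and 40% of the faces) as separate .npy files."""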
    feature_emotion_path = config['MODEL']['FEATURE_EMOTION']
    face_save_dir = config['VIDEO']['FACE_IMAGE_DIR']
    face_feature_dir = config['VIDEO']['FACE_FEATURE_DIR']
    input_size = config['EMOTION']['INPUT_SIZE']

    feature_extractor = FeatureExtractor(
        input_size=input_size, out_put_layer='avg_pool', model_path=feature_emotion_path)

    video_dirs = []
    class_names = os.listdir(face_save_dir)
    for class_name in class_names:
        class_dir = os.path.join(face_save_dir, class_name)
        video_names = os.listdir(class_dir)
        for video_name in video_names:
            video_dir = os.path.join(class_dir, video_name)
            video_dirs.append(video_dir)

    for video_dir_index, video_dir in enumerate(video_dirs):
        print('{}/{}'.format(video_dir_index, len(video_dirs)))
        class_name, video_name = video_dir.split('/')[-2], video_dir.split('/')[-1]

        video_file_name = video_name.split('.')[0]
        save_class_dir = os.path.join(face_feature_dir, class_name)
        face_feature_path = os.path.join(save_class_dir, video_file_name + '.npy')
        if os.path.exists(face_feature_path):
            print('file already exists')
            continue

        image_names = os.listdir(video_dir)
        image_dirs = []
        for image_name in image_names:
            image_dir = os.path.join(video_dir, image_name)
            image_dirs.append(image_dir)

        features = []
        for image_dir_index, image_dir in enumerate(image_dirs):
            sub_face_names = os.listdir(image_dir)
            sub_face_num = len(sub_face_names)
            for face_index in range(sub_face_num):
                face_path = os.path.join(image_dir, sub_face_names[face_index])
                face_image = cv2.imread(face_path)
                # inference() already returns the feature vector for a single image.
                features.append(feature_extractor.inference(face_image))

        face_num = len(features)
        random_1 = random.sample(range(face_num), int(0.8 * face_num))
        features_random_1 = [features[c] for c in random_1]

        random_2 = random.sample(range(face_num), int(0.6 * face_num))
        features_random_2 = [features[d] for d in random_2]

        random_3 = random.sample(range(face_num), int(0.4 * face_num))
        features_random_3 = [features[e] for e in random_3]

        if len(features) == 0:
            continue

        if not os.path.exists(save_class_dir):
            os.mkdir(save_class_dir)

        features_np = np.array(features)
        face_feature_path = os.path.join(save_class_dir, video_file_name + '.npy')
        np.save(face_feature_path, features_np)

        features_np_random_1 = np.array(features_random_1)
        face_feature_1_path = os.path.join(save_class_dir, video_file_name + '_1.npy')
        np.save(face_feature_1_path, features_np_random_1)

        features_np_random_2 = np.array(features_random_2)
        face_feature_2_path = os.path.join(save_class_dir, video_file_name + '_2.npy')
        np.save(face_feature_2_path, features_np_random_2)

        features_np_random_3 = np.array(features_random_3)
        face_feature_3_path = os.path.join(save_class_dir, video_file_name + '_3.npy')
        np.save(face_feature_3_path, features_np_random_3)


def get_vid_fea(pics_features):
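    """Pool per-face feature vectors into one video-level descriptor by concatenating
    their element-wise mean, max, min, and standard deviation."""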
    pics_features = np.array(pics_features)
    fea_mean = pics_features.mean(axis=0)
    fea_max = np.amax(pics_features, axis=0)
    fea_min = np.amin(pics_features, axis=0)
    fea_std = pics_features.std(axis=0)

    # Each statistic is a 1-D vector, so concatenate along their single axis.
    feature_concate = np.concatenate((fea_mean, fea_max, fea_min, fea_std), axis=0)
    return np.squeeze(feature_concate)


def extract_random_face_and_frame_feature_():
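    """Aggregate previously saved per-face feature files (hard-coded mirror-run paths) into
    video-level descriptors: the full set plus 90%, 70%, and 50% random subsets."""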
    face_feature_dir = r'/data2/3_log-ResNet50/train_mirror/'
    new_face_feature_dir = r'/data2/retinaface/random_face_frame_features_train_mirror/'

    video_dirs = []
    class_names = os.listdir(face_feature_dir)
    for class_name in class_names:
        class_dir = os.path.join(face_feature_dir, class_name)
        video_names = os.listdir(class_dir)
        for video_name in video_names:
            video_dir = os.path.join(class_dir, video_name)
            video_dirs.append(video_dir)

    for video_dir in video_dirs:
        video_name = video_dir.split('/')[-1]
        frame_names = os.listdir(video_dir)
        feature = []
        for frame_name in frame_names:
            feature_dir = os.path.join(video_dir, frame_name)
            face_features_names = os.listdir(feature_dir)
            for face_features_name in face_features_names:
                face_features_path = os.path.join(feature_dir, face_features_name)
                feature_np = np.load(face_features_path, allow_pickle=True)
                feature.append(feature_np)

        feature_num = len(feature)
        if feature_num < 4:
            continue

        random_1 = random.sample(range(feature_num), int(0.9 * feature_num))
        features_random_1 = [feature[c] for c in random_1]

        random_2 = random.sample(range(feature_num), int(0.7 * feature_num))
        features_random_2 = [feature[d] for d in random_2]

        random_3 = random.sample(range(feature_num), int(0.5 * feature_num))
        features_random_3 = [feature[e] for e in random_3]

        video_file_name = video_name.split('.')[0]

        features_np = get_vid_fea(feature)
        face_feature_path = os.path.join(new_face_feature_dir, video_file_name + '.npy')
        np.save(face_feature_path, features_np)

        features_np_random_1 = get_vid_fea(features_random_1)
        face_feature_1_path = os.path.join(new_face_feature_dir, video_file_name + '_1.npy')
        np.save(face_feature_1_path, features_np_random_1)

        features_np_random_2 = get_vid_fea(features_random_2)
        face_feature_2_path = os.path.join(new_face_feature_dir, video_file_name + '_2.npy')
        np.save(face_feature_2_path, features_np_random_2)

        features_np_random_3 = get_vid_fea(features_random_3)
        face_feature_3_path = os.path.join(new_face_feature_dir, video_file_name + '_3.npy')
        np.save(face_feature_3_path, features_np_random_3)


def extract_random_face_and_frame_feature(config):
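    """Mirror every face crop under the hard-coded train directory, embed it with the
    emotion model, and save one feature .npy per crop, mirroring the input folder layout."""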
    feature_emotion_path = config['MODEL']['FEATURE_EMOTION']
    input_size = config['EMOTION']['INPUT_SIZE']
    face_dir = r'/data2/retinaface/train/'
    new_face_feature_dir = r'/data2/3_log-ResNet50/train_mirror/'

    feature_extractor = FeatureExtractor(
        input_size=input_size, out_put_layer='avg_pool', model_path=feature_emotion_path)

    sub_face_paths = []
    class_names = os.listdir(face_dir)
    for class_name in class_names:
        class_dir = os.path.join(face_dir, class_name)
        video_names = os.listdir(class_dir)
        for video_name in video_names:
            video_dir = os.path.join(class_dir, video_name)
            frame_names = os.listdir(video_dir)
            for frame_name in frame_names:
                frame_dir = os.path.join(video_dir, frame_name)
                sub_face_names = os.listdir(frame_dir)
                for sub_face_name in sub_face_names:
                    sub_face_path = os.path.join(frame_dir, sub_face_name)
                    sub_face_paths.append(sub_face_path)

    for face_index, sub_face_path in enumerate(sub_face_paths):
        print('{}/{}'.format(face_index+1, len(sub_face_paths)))
        path_parts = sub_face_path.split('/')
        class_name, video_name, frame_name, sub_face_name = path_parts[-4], path_parts[-3], path_parts[-2], path_parts[-1]

        class_dir = os.path.join(new_face_feature_dir, class_name)
        video_dir = os.path.join(class_dir, video_name)
        frame_dir = os.path.join(video_dir, frame_name)
        sub_face_name = sub_face_name.split('.')[0] + '.npy'
        face_feature_save_path = os.path.join(frame_dir, sub_face_name)
        if os.path.exists(face_feature_save_path):
            print('file exists')
            continue 
        
        face_image = cv2.imread(sub_face_path)
        # flipCode=1 mirrors the crop horizontally; 0 would flip it vertically.
        mirror_face_image = cv2.flip(face_image, 1)
        # inference() already returns the feature vector for a single image.
        feature = feature_extractor.inference(mirror_face_image)

        if not os.path.exists(class_dir):
            os.mkdir(class_dir)

        if not os.path.exists(video_dir):
            os.mkdir(video_dir)

        if not os.path.exists(frame_dir):
            os.mkdir(frame_dir)

        np.save(face_feature_save_path, feature)


def gen_file_list(frame_save_dir, frame_list_dir):
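    """Write a TSN-style list file with one '<video_name> <frame_count>' line per video and
    return its path."""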

    val_path = os.path.join(frame_list_dir, 'train.txt')
    video_names = os.listdir(frame_save_dir)

    with open(val_path, 'w') as list_file:
        for video_name in video_names:
            images_dir = os.path.join(frame_save_dir, video_name)
            list_file.write('{} {}\n'.format(video_name, len(os.listdir(images_dir))))

    return val_path



def extract_video_features(config):
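    """Run the (temporal-shift) TSN model over every frame folder listed by gen_file_list
    and save one clip-level feature .npy per video into the hard-coded feature_save_dir."""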
    arch = config['FIGHTING']['ARCH']
    prefix = config['VIDEO']['PREFIX']
    modality = config['VIDEO_FILTER']['MODALITY']
    test_crop = config['VIDEO_FILTER']['TEST_CROP']
    batch_size = config['VIDEO_FILTER']['BATCH_SIZE']
    weights_path = config['MODEL']['CLS_VIDEO']
    test_segment = config['VIDEO_FILTER']['TEST_SEGMENT']
    frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    feature_save_dir = r'/home/jwq/Desktop/tmp/video2np/train/'

    workers = 8
    num_class = 3
    shift_div = 8
    img_feature_dim = 256

    softmax = False
    is_shift = True
    full_res = False
    non_local = False
    dense_sample = False
    twice_sample = False

    val_list = gen_file_list(frame_save_dir, frame_list_dir)
    
    pretrain = 'imagenet'
    shift_place = 'blockres'
    crop_fusion_type = 'avg'

    net = TSN(num_class, test_segment if is_shift else 1, modality,
              base_model=arch,
              consensus_type=crop_fusion_type,
              img_feature_dim=img_feature_dim,
              pretrain=pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local=non_local,
              )

    checkpoint = torch.load(weights_path)
    checkpoint = checkpoint['state_dict']

    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)

    input_size = net.scale_size if full_res else net.input_size

    if test_crop == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif test_crop == 3:  # do not flip, so only 3 full-resolution crops
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 10:
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size)
        ])
    else:
        raise ValueError("Only 1, 3, 5, 10 crops are supported while we got {}".format(test_crop))

    data_loader = torch.utils.data.DataLoader(
            TSNDataSet(frame_save_dir, val_list, num_segments=test_segment,
                       new_length=1 if modality == "RGB" else 5,
                       modality=modality,
                       image_tmpl=prefix,
                       test_mode=True,
                       remove_missing=False,
                       transform=torchvision.transforms.Compose([
                           cropping,
                           Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
                           GroupNormalize(net.input_mean, net.input_std),
                       ]), dense_sample=dense_sample, twice_sample=twice_sample),
            batch_size=batch_size, shuffle=False,
            num_workers=workers, pin_memory=True,
    )

    net = torch.nn.DataParallel(net.cuda())
    net.eval()
    data_gen = enumerate(data_loader)
    max_num = len(data_loader.dataset)

    for i, data_pair in data_gen:
        directory, data = data_pair
        with torch.no_grad():
            if i >= max_num:
                break
            num_crop = test_crop
            if dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample

            if twice_sample:
                num_crop *= 2

            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality " + modality)

            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                # Reshape with -1 so the last (possibly smaller) batch still works.
                data_in = data_in.view(-1, test_segment, length, data_in.size(2), data_in.size(3))
            rst, feature = net(data_in)

            feature = np.squeeze(feature.cpu().numpy())
            print(feature.shape)
            feature_name = str(directory[0]) + '.npy'
            feature_save_path = os.path.join(feature_save_dir, feature_name)
            np.save(feature_save_path, feature)


if __name__ == '__main__':
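    # Select the pipeline stage to run by editing this call; the config-driven stages above expect a config dict.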
    extract_random_face_and_frame_feature_()