# video_filter.py
import os
import torch.optim
import torch.utils.data
import numpy as np
import torch.nn.parallel
from ops.models import TSN
from ops.transforms import *  # the code below relies on this wildcard import for torchvision and the Group* transforms
from ops.dataset import TSNDataSet
from torch.nn import functional as F


def gen_file_list(frame_save_dir, frame_list_dir):
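    """Write one '<video_name> <num_frames>' line per frame folder found in
    frame_save_dir into <frame_list_dir>/val.txt and return the path of that list."""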

    val_path = os.path.join(frame_list_dir, 'val.txt')
    video_names = os.listdir(frame_save_dir)
    ucf101_rgb_val_file = open(val_path, 'w')

    for video_name in video_names:
        images_dir = os.path.join(frame_save_dir, video_name)
        ucf101_rgb_val_file.write(video_name)
        ucf101_rgb_val_file.write(' ')
        ucf101_rgb_val_file.write(str(len(os.listdir(images_dir))))
        ucf101_rgb_val_file.write('\n')

    ucf101_rgb_val_file.close()

    return val_path


def start_filter(config):
    arch = config['FIGHTING']['ARCH']
    prefix = config['VIDEO']['PREFIX']
    modality = config['VIDEO_FILTER']['MODALITY']
    test_crop = config['VIDEO_FILTER']['TEST_CROP']
    batch_size = config['VIDEO_FILTER']['BATCH_SIZE']
    weights_path = config['MODEL']['CLS_VIDEO']
    test_segment = config['VIDEO_FILTER']['TEST_SEGMENT']
    frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['VIDEO_FILTER']['RESULT_FILE']

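    # Inference settings that are hard-coded here rather than read from the config.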
    workers = 8
    num_class = 3
    shift_div = 8
    img_feature_dim = 256

    softmax = False
    is_shift = True
    full_res = False
    non_local = False
    dense_sample = False
    twice_sample = False

    val_list = gen_file_list(frame_save_dir, frame_list_dir)
    result_file_path = os.path.join(frame_list_dir, result_file_name)

    pretrain = 'imagenet'
    shift_place = 'blockres'
    crop_fusion_type = 'avg'

    net = TSN(num_class, test_segment if is_shift else 1, modality,
              base_model=arch,
              consensus_type=crop_fusion_type,
              img_feature_dim=img_feature_dim,
              pretrain=pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local=non_local,
              )

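    # Load the TSM checkpoint: drop the leading 'module.' prefix that DataParallel adds
    # to every key, then map the base model's classifier weights onto TSN's new_fc layer.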
    checkpoint = torch.load(weights_path)
    checkpoint = checkpoint['state_dict']

    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)

    input_size = net.scale_size if full_res else net.input_size

    if test_crop == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif test_crop == 3:  # do not flip, so only 3 crops
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 10:
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size)
        ])
    else:
        raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(test_crop))

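    # Each loader item covers one video: the sampled segments and crops are stacked
    # along the channel dimension by the Stack()/ToTorchFormatTensor() transforms.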
    data_loader = torch.utils.data.DataLoader(
            TSNDataSet(frame_save_dir, val_list, num_segments=test_segment,
                       new_length=1 if modality == "RGB" else 5,
                       modality=modality,
                       image_tmpl=prefix,
                       test_mode=True,
                       remove_missing=False,
                       transform=torchvision.transforms.Compose([
                           cropping,
                           Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
                           GroupNormalize(net.input_mean, net.input_std),
                       ]), dense_sample=dense_sample, twice_sample=twice_sample),
            batch_size=batch_size, shuffle=False,
            num_workers=workers, pin_memory=True,
    )

    net = torch.nn.DataParallel(net.cuda())
    net.eval()
    data_gen = enumerate(data_loader)
    max_num = len(data_loader.dataset)

    result_file = open(result_file_path, 'w')

    for i, data_pair in data_gen:
        directory, data = data_pair
        with torch.no_grad():
            if i >= max_num:
                break
            num_crop = test_crop
            if dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample

            if twice_sample:
                num_crop *= 2

            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality " + modality)
 
            # Use the actual batch size: the final batch can be smaller than batch_size
            # when the number of videos is not divisible by it.
            this_batch = data.size(0)
            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(this_batch * num_crop, test_segment, length, data_in.size(2), data_in.size(3))

            rst, feature = net(data_in)
            rst = rst.reshape(this_batch, num_crop, -1).mean(1)

            if softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)

            rst = rst.data.cpu().numpy().copy()

            if net.module.is_shift:
                rst = rst.reshape(this_batch, num_class)
            else:
                rst = rst.reshape((this_batch, -1, num_class)).mean(axis=1).reshape((this_batch, num_class))

            # Normalize the scores with a softmax and write one
            # "<video_name> p0,p1,p2" line per video in the batch.
            proba = np.exp(rst) / np.exp(rst).sum(axis=1, keepdims=True)
            for b in range(this_batch):
                result_file.write(str(directory[b]) + ' ')
                result_file.write(','.join(str(p) for p in proba[b]) + '\n')

    result_file.close()
    print('video filter end')
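

# Example invocation (a minimal sketch, not part of the original pipeline): the config
# layout is inferred from the keys read in start_filter(); the section/key names come
# from the code above, while every value (paths, architecture, crop/segment counts) is
# only a placeholder to be replaced with the real project settings.
if __name__ == '__main__':
    example_config = {
        'FIGHTING': {'ARCH': 'resnet50'},
        'VIDEO': {
            'PREFIX': 'img_{:05d}.jpg',
            'FRAME_SAVE_DIR': '/path/to/extracted_frames',
            'FRAME_LIST_DIR': '/path/to/frame_lists',
        },
        'VIDEO_FILTER': {
            'MODALITY': 'RGB',
            'TEST_CROP': 1,
            'BATCH_SIZE': 1,
            'TEST_SEGMENT': 8,
            'RESULT_FILE': 'video_filter_result.txt',
        },
        'MODEL': {'CLS_VIDEO': '/path/to/tsm_checkpoint.pth'},
    }
    start_filter(example_config)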