import os

import numpy as np
import torch.optim
import torch.nn.parallel
import torchvision
from torch.nn import functional as F

from ops.models import TSN
from ops.transforms import *
from ops.dataset import TSNDataSet


def gen_file_list(frame_save_dir, frame_list_dir):
    """Write one '<video_name> <frame_count>' line per video directory to val.txt."""
    val_path = os.path.join(frame_list_dir, 'val.txt')
    video_names = os.listdir(frame_save_dir)
    with open(val_path, 'w') as val_file:
        for video_name in video_names:
            images_dir = os.path.join(frame_save_dir, video_name)
            num_frames = len(os.listdir(images_dir))
            val_file.write('{} {}\n'.format(video_name, num_frames))
    return val_path
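# For reference, the generated val.txt pairs each video's directory name with
# its frame count, one video per line. With two hypothetical videos it would
# look like:
#
#   fight_0001 142
#   normal_0042 96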
def start_filter(config):
    arch = config['FIGHTING_2']['ARCH']
    prefix = config['VIDEO']['PREFIX']
    modality = config['FIGHTING_2']['MODALITY']
    test_crop = config['FIGHTING_2']['TEST_CROP']
    batch_size = config['FIGHTING_2']['BATCH_SIZE']
    weights_path = config['MODEL']['CLS_FIGHTING_2']
    test_segment = config['FIGHTING_2']['TEST_SEGMENT']
    frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['FIGHTING_2']['RESULT_FILE']

    # Fixed inference settings.
    workers = 8
    num_class = 2
    shift_div = 8
    img_feature_dim = 256
    softmax = False
    is_shift = True
    full_res = False
    non_local = False
    dense_sample = False
    twice_sample = False

    val_list = gen_file_list(frame_save_dir, frame_list_dir)
    result_file_path = os.path.join(frame_list_dir, result_file_name)

    pretrain = 'imagenet'
    shift_place = 'blockres'
    crop_fusion_type = 'avg'

    net = TSN(num_class, test_segment if is_shift else 1, modality,
              base_model=arch,
              consensus_type=crop_fusion_type,
              img_feature_dim=img_feature_dim,
              pretrain=pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local=non_local,
              )

    checkpoint = torch.load(weights_path)
    checkpoint = checkpoint['state_dict']

    # The checkpoint was saved from a DataParallel model: strip the leading
    # 'module.' from parameter names, then map the trained classifier weights
    # onto the new_fc layer of the freshly built TSN.
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)

    input_size = net.scale_size if full_res else net.input_size
    if test_crop == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif test_crop == 3:  # do not flip, so only 3 full-resolution crops
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 10:
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size)
        ])
    else:
        raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop))

    data_loader = torch.utils.data.DataLoader(
        TSNDataSet(frame_save_dir, val_list, num_segments=test_segment,
                   new_length=1 if modality == "RGB" else 5,
                   modality=modality,
                   image_tmpl=prefix,
                   test_mode=True,
                   remove_missing=False,
                   transform=torchvision.transforms.Compose([
                       cropping,
                       Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
                       GroupNormalize(net.input_mean, net.input_std),
                   ]),
                   dense_sample=dense_sample,
                   twice_sample=twice_sample),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True,
    )

    net = torch.nn.DataParallel(net.cuda())
    net.eval()

    data_gen = enumerate(data_loader)
    max_num = len(data_loader.dataset)

    result_file = open(result_file_path, 'w')
    for i, data_pair in data_gen:
        directory, data = data_pair
        with torch.no_grad():
            # max_num counts videos, not batches; the two coincide when
            # batch_size == 1.
            if i >= max_num:
                break

            num_crop = test_crop
            if dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample
            if twice_sample:
                num_crop *= 2

            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality " + modality)

            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, test_segment, length,
                                       data_in.size(2), data_in.size(3))

            # This TSN variant returns (logits, feature); the feature is unused here.
            rst, feature = net(data_in)
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)  # average over crops

            if softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)

            rst = rst.data.cpu().numpy().copy()

            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

            proba = np.squeeze(rst)
            print(proba)
            # Normalize logits to probabilities; the squeeze and the [0]/[1]
            # indexing below assume batch_size == 1.
            proba = np.exp(proba) / sum(np.exp(proba))
            result_file.write(str(directory[0]) + ' ')
            result_file.write(str(proba[0]) + ',' + str(proba[1]) + '\n')

    result_file.close()
    print('fighting filter end')
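# Example usage: a minimal sketch of the nested config that start_filter reads.
# All keys below appear in start_filter; the concrete values and paths are
# hypothetical placeholders, not taken from this repository.
if __name__ == '__main__':
    example_config = {
        'VIDEO': {
            'PREFIX': 'img_{:05d}.jpg',           # frame filename template (assumed)
            'FRAME_SAVE_DIR': '/path/to/frames',  # one sub-directory of frames per video
            'FRAME_LIST_DIR': '/path/to/lists',   # where val.txt and the result file go
        },
        'MODEL': {
            'CLS_FIGHTING_2': '/path/to/tsm_checkpoint.pth',
        },
        'FIGHTING_2': {
            'ARCH': 'resnet50',
            'MODALITY': 'RGB',
            'TEST_CROP': 1,
            'BATCH_SIZE': 1,   # the per-video softmax in start_filter assumes batch_size == 1
            'TEST_SEGMENT': 8,
            'RESULT_FILE': 'fighting_result.txt',
        },
    }
    start_filter(example_config)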