cb29b6d7 by jiangwenqiang

first commit

1 parent 78b00ada
1 import os
2 import csv
3 import pickle
4 import numpy as np
5 import joblib  # sklearn.externals.joblib was removed in newer scikit-learn; use the top-level joblib package
6
7
8 def start_filter(config):
9 cls_audio_path = config['MODEL']['CLS_AUDIO']
10 feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR']
11 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
12 result_file_name = config['AUDIO']['RESULT_FILE']
13 feature_name = config['AUDIO']['DATA_NAME']
14
15 svm_clf = joblib.load(cls_audio_path)
16
17 result_file_path = os.path.join(frame_list_dir, result_file_name)
18 result_file = open(result_file_path, 'w')
19
20 feature_path = os.path.join(feature_save_dir, feature_name)
21 val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1')
22
23 for pair in val_annotation_pairs:
24
25 v = pair[0]
26 n = pair[2]
27
28 feature_np = np.reshape(v, (1, -1))
29 res = svm_clf.predict_proba(feature_np)
30 proba = np.squeeze(res)
31
32 # class_pre = svm_clf.predict(feature_np)
33
34 result_file.write(str(pair[2])[:-4] + ' ')
35 result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n')
36
37 result_file.close()
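# Descriptive note (not in the original commit): each line written above has the form
# "<clip_name> <p0>,<p1>,<p2>" -- a space-separated clip id followed by the three class
# probabilities from the audio SVM. A minimal way to read it back, for reference:
#   probs = {name: [float(s) for s in scores.split(',')]
#            for name, scores in (line.split() for line in open(result_file_path))}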
38
39
40
41
42
43 def start_filter_xgboost(config):
44 cls_class_path = config['MODEL']['CLS_AUDIO']
45 feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR']
46 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
47 result_file_name = config['AUDIO']['RESULT_FILE']
48 feature_name = config['AUDIO']['DATA_NAME']
49
50 xgboost_model = pickle.load(open(cls_class_path, "rb"))
51
52 result_file_path = os.path.join(frame_list_dir, result_file_name)
53 result_file = open(result_file_path, 'w')
54
55 feature_path = os.path.join(feature_save_dir, feature_name)
56 val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1')
57
58 X_val = []
59 Y_names = []
60 for pair in val_annotation_pairs:
61 (n, v), = pair.items()  # assumes each element is a single-entry {name: feature} dict; plain "n, v = pair.items()" would raise
62 X_val.append(v)
63 Y_names.append(n)
64
65 X_val = np.array(X_val)
66 y_pred = xgboost_model.predict_proba(X_val)
67
68 for i, Y_name in enumerate(Y_names):
69 result_file.write(Y_name + ' ')
70 result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n')
71
72 result_file.close()
73
1 import os
2 import cv2
3 import numpy as np
4 import pickle
5
6 def start_filter(config):
7 cls_class_path = config['MODEL']['CLS_BG']
8 feature_save_dir = config['VIDEO']['FACE_FEATURE_DIR']
9 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
10 result_file_name = config['BG']['RESULT_FILE']
11 feature_name = config['BG']['DATA_NAME']
12
13 xgboost_model = pickle.load(open(cls_class_path, "rb"))
14
15 result_file_path = os.path.join(frame_list_dir, result_file_name)
16 result_file = open(result_file_path, 'w')
17
18 feature_path = os.path.join(feature_save_dir, feature_name)
19 val_annotation_pairs = np.load(feature_path, allow_pickle=True)
20
21 X_val = []
22 Y_val = []
23 Y_names = []
24 for j in range(len(val_annotation_pairs)):
25 pair = val_annotation_pairs[j]
26 X_val.append(np.squeeze(pair[0]))
27 Y_val.append(pair[1])
28 Y_names.append(pair[2])
29
30 X_val = np.array(X_val)
31 y_pred = xgboost_model.predict_proba(X_val)
32
33 for i, Y_name in enumerate(Y_names):
34 result_file.write(Y_name + ' ')
35 result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n')
36
37 result_file.close()
38
39
40
41
42
1 import os
2 import pickle
3 import numpy as np
4
5
6 def start_filter(config):
7
8 cls_class_path = config['MODEL']['CLS_CLASS']
9 feature_save_dir = config['VIDEO']['CLASS_FEATURE_DIR']
10 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
11 result_file_name = config['CLASS']['RESULT_FILE']
12 feature_name = config['CLASS']['DATA_NAME']
13
14 xgboost_model = pickle.load(open(cls_class_path, "rb"))
15
16 result_file_path = os.path.join(frame_list_dir, result_file_name)
17 result_file = open(result_file_path, 'w')
18
19 feature_path = os.path.join(feature_save_dir, feature_name)
20 val_annotation_pairs = np.load(feature_path, allow_pickle=True)
21
22 X_val = []
23 Y_val = []
24 Y_names = []
25 for j in range(len(val_annotation_pairs)):
26 pair = val_annotation_pairs[j]
27 X_val.append(pair[0])
28 Y_val.append(pair[1])
29 Y_names.append(pair[2])
30
31 X_val = np.array(X_val)
32 y_pred = xgboost_model.predict(X_val)
33
34 for i, Y_name in enumerate(Y_names):
35 result_file.write(Y_name + ' ')
36 result_file.write(str(y_pred[i]) + '\n')
37
38 result_file.close()
1 MODEL:
2 CLS_FIGHTING_2: '/home/jwq/models/cls_fighting_2/cls_fighting_2_v0.0.1.pth'
3 CLS_EMOTION: '/home/jwq/models/cls_emotion/v0.1.0.m'
4 FEATURE_EMOTION: '/home/jwq/models/feature_emotion/FerPlus3.h5'
5 CLS_AUDIO: '/home/jwq/models/cls_audio/v0.0.1.m'
6 CLS_CLASS: '/home/jwq/models/cls_class/v_0.0.1_xgb.pkl'
7 CLS_VIDEO: '/home/jwq/models/cls_video/v0.4.1.pth'
8 CLS_POSE: '/home/jwq/models/cls_pose/v0.0.1.pth'
9 CLS_FLOW: '/home/jwq/models/cls_flow/v0.1.1.pth'
10 CLS_BG: '/home/jwq/models/cls_bg/v0.1.1.pkl'
11 CLS_PERSON: '/home/jwq/models/cls_person/v0.1.1.pkl'
12
13 THRESHOLD:
14 FACES_THRESHOLD: 0.6
15
16 FILTER:
17
18
19 VIDEO:
20 VIDEO_DIR: '/home/jwq/Desktop/VGAF_EmotiW/Val'
21 LABEL_PATH: '/home/jwq/Desktop/VGAF_EmotiW/Val_labels.txt'
22 VIDEO_SAVE_DIR: '/home/jwq/Desktop/tmp/video'
23 AUDIO_SAVE_DIR: '/home/jwq/npys/'
24 FRAME_SAVE_DIR: '/home/jwq/Desktop/tmp/frame'
25 # FRAME_SAVE_DIR: '/home/jwq/Desktop/VGAF_EmotiW_class/train_frame'
26 FLOW_SAVE_DIR: '/home/jwq/Desktop/tmp/flow'
27 POSE_FRAME_SAVE_DIR: '/home/jwq/Desktop/tmp/pose_frame'
28 FRAME_LIST_DIR: '/home/jwq/Desktop/tmp/file_list'
29 IS10_FEATURE_NP_DIR: '/home/jwq/npys'
30 IS10_FEATURE_CSV_DIR: '/home/jwq/Desktop/tmp/is10'
31 # FACE_FEATURE_DIR: '/home/jwq/Desktop/tmp/face_feature_retina'
32 # FACE_FEATURE_DIR: '/data2/retinaface/random_face_frame_features/'
33 FACE_FEATURE_DIR: '/data1/segment/'
34 # FACE_FEATURE_DIR: '/home/jwq/npys/'
35 FACE_IMAGE_DIR: '/data2/retinaface/train/'
36 CLASS_FEATURE_DIR: '/home/jwq/Desktop/tmp/class'
37 PREFIX: 'img_{:05d}.jpg'
38 FLOW_PREFIX: 'flow_{}_{:05d}.jpg'
39 THREAD_NUM: 10
40 FPS: 5
41
42 VIDEO_FILTER:
43 TEST_SEGMENT: 8
44 TEST_CROP: 1
45 BATCH_SIZE: 1
46 INPUT_SIZE: 224
47 MODALITY: 'RGB'
48 ARCH: 'resnet50'
49 RESULT_FILE: 'video_filter.txt'
50
51 VIDEO_1_FILTER:
52 TEST_SEGMENT: 8
53 TEST_CROP: 1
54 BATCH_SIZE: 1
55 INPUT_SIZE: 224
56 MODALITY: 'RGB'
57 ARCH: 'resnet34'
58 RESULT_FILE: 'video_1_filter.txt'
59
60 EMOTION:
61 INTERVAL: 1
62 INPUT_SIZE: 224
63 RESULT_FILE: 'emotion_filter.txt'
64
65 EMOTION_1:
66 RESULT_FILE: 'emotion_1_filter.txt'
67 DATA_NAME: 'val.npy'
68
69 ARGUE:
70 DIMENSION: 1582
71 RESULT_FILE: 'argue_filter.txt'
72
73 FIGHTING:
74 TEST_SEGMENT: 8
75 TEST_CROP: 1
76 BATCH_SIZE: 1
77 INPUT_SIZE: 224
78 MODALITY: 'RGB'
79 ARCH: 'resnet50'
80 RESULT_FILE: 'fighting_filter.txt'
81
82 FIGHTING_2:
83 TEST_SEGMENT: 8
84 TEST_CROP: 1
85 BATCH_SIZE: 1
86 INPUT_SIZE: 224
87 MODALITY: 'RGB'
88 ARCH: 'resnet50'
89 RESULT_FILE: 'fighting_2_filter.txt'
90
91 MEETING:
92 TEST_SEGMENT: 8
93 TEST_CROP: 1
94 BATCH_SIZE: 1
95 INPUT_SIZE: 224
96 MODALITY: 'RGB'
97 ARCH: 'resnet50'
98 RESULT_FILE: 'meeting_filter.txt'
99
100 TROOPS:
101 TEST_SEGMENT: 8
102 TEST_CROP: 1
103 BATCH_SIZE: 1
104 INPUT_SIZE: 224
105 MODALITY: 'RGB'
106 ARCH: 'resnet50'
107 RESULT_FILE: 'troops_filter.txt'
108
109 FLOW:
110 TEST_SEGMENT: 8
111 TEST_CROP: 1
112 BATCH_SIZE: 1
113 INPUT_SIZE: 224
114 MODALITY: 'Flow'
115 ARCH: 'resnet50'
116 RESULT_FILE: 'flow_filter.txt'
117
118
119 FINAL:
120 RESULT_FILE: 'final.txt'
121 ERROR_FILE: 'error.txt'
122 SIM_FILE: 'image_sim.txt'
123
124 AUDIO:
125 RESULT_FILE: 'audio.txt'
126 OPENSMILE_DIR: '/home/jwq/Downloads/opensmile-2.3.0'
127 DATA_NAME: 'val.npy'
128
129 CLASS:
130 RESULT_FILE: 'class.txt'
131 DATA_NAME: 'val_reannotation.npy'
132
133 POSE:
134 TEST_SEGMENT: 8
135 TEST_CROP: 1
136 BATCH_SIZE: 1
137 INPUT_SIZE: 224
138 MODALITY: 'RGB'
139 ARCH: 'resnet50'
140 RESULT_FILE: 'pose_filter.txt'
141
142 BG:
143 RESULT_FILE: 'bg_filter.txt'
144 DATA_NAME: 'bg_val_feature.npy'
145
146 PERSON:
147 RESULT_FILE: 'person_filter.txt'
148 DATA_NAME: 'person_val_feature.npy'
149
150
1 import os
2 import cv2
3 import numpy as np
4 from keras.models import Model
5 from keras.models import load_model
6 import joblib  # sklearn.externals.joblib was removed in newer scikit-learn; use the top-level joblib package
7 from tensorflow.keras.preprocessing.image import img_to_array
8
9 os.environ["CUDA_VISIBLE_DEVICES"] = '0'
10 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
11
12
13 class FeatureExtractor(object):
14 def __init__(self, input_size=224, out_put_layer='avg_pool', model_path='FerPlus3.h5'):
15 self.model = load_model(model_path)
16 self.input_size = input_size
17 self.model_inter = Model(inputs=self.model.input, outputs=self.model.get_layer(out_put_layer).output)
18
19 def inference(self, image):
20 image = cv2.resize(image, (self.input_size, self.input_size))
21 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
22 image = image.astype("float") / 255.0
23 image = img_to_array(image)
24 image = np.expand_dims(image, axis=0)
25 feature = self.model_inter.predict(image)[0]
26 return feature
27
28
29 def features2feature(pics_features):
30
31 pics_features = np.array(pics_features)
32 fea_mean = pics_features.mean(axis=0)
33 fea_max = np.amax(pics_features, axis=0)
34 fea_min = np.amin(pics_features, axis=0)
35 fea_std = pics_features.std(axis=0)
36
37 return np.concatenate((fea_mean, fea_max, fea_min, fea_std), axis=1).reshape(1, -1)
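# Descriptive note (not in the original commit): features2feature pools a variable number of
# per-face embeddings into one fixed-length clip descriptor by concatenating their mean, max,
# min and std statistics, so the emotion SVM below always receives the same input dimension.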
38
39
40 def start_filter(config):
41
42 cls_emotion_path = config['MODEL']['CLS_EMOTION']
43 face_feature_dir = config['VIDEO']['FACE_FEATURE_DIR']
44 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
45 result_file_name = config['EMOTION']['RESULT_FILE']
46
47 svm_clf = joblib.load(cls_emotion_path)
48
49 result_file_path = os.path.join(frame_list_dir, result_file_name)
50 result_file = open(result_file_path, 'w')
51
52 face_feature_names = os.listdir(face_feature_dir)
53 for face_feature in face_feature_names:
54 face_feature_path = os.path.join(face_feature_dir, face_feature)
55
56 features_np = np.load(face_feature_path, allow_pickle=True)
57
58 feature = features2feature(features_np)
59 res = svm_clf.predict_proba(feature)
60 proba = np.squeeze(res)
61 # class_pre = svm_clf.predict(feature)
62
63 result_file.write(face_feature[:-4] + ' ')
64 result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n')
65
66 result_file.close()
67
68
69
70
71
1 import os
2 import torch.optim
3 import numpy as np
4 import torch.optim
5 import torch.nn.parallel
6 from ops.models import TSN
7 from ops.transforms import *
8 from ops.dataset import TSNDataSet
9 from torch.nn import functional as F
10
11
12 def gen_file_list(frame_save_dir, frame_list_dir):
13
14 val_path = os.path.join(frame_list_dir, 'val.txt')
15 video_names = os.listdir(frame_save_dir)
16 ucf101_rgb_val_file = open(val_path, 'w')
17
18 for video_name in video_names:
19 images_dir = os.path.join(frame_save_dir, video_name)
20 ucf101_rgb_val_file.write(video_name)
21 ucf101_rgb_val_file.write(' ')
22 ucf101_rgb_val_file.write(str(len(os.listdir(images_dir))))
23 ucf101_rgb_val_file.write('\n')
24
25 ucf101_rgb_val_file.close()
26
27 return val_path
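# Descriptive note (not in the original commit): the generated list file contains one
# "<frame_dir_name> <num_frames>" line per clip; TSNDataSet (ops/dataset.py) parses exactly
# these two fields via VideoRecord.path and VideoRecord.num_frames.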
28
29
30 def start_filter(config):
31 arch = config['FIGHTING_2']['ARCH']
32 prefix = config['VIDEO']['PREFIX']
33 modality = config['FIGHTING_2']['MODALITY']
34 test_crop = config['FIGHTING_2']['TEST_CROP']
35 batch_size = config['FIGHTING_2']['BATCH_SIZE']
36 weights_path = config['MODEL']['CLS_FIGHTING_2']
37 test_segment = config['FIGHTING_2']['TEST_SEGMENT']
38 frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR']
39 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
40 result_file_name = config['FIGHTING_2']['RESULT_FILE']
41
42 workers = 8
43 num_class = 2
44 shift_div = 8
45 img_feature_dim = 256
46
47 softmax = False
48 is_shift = True
49 full_res = False
50 non_local = False
51 dense_sample = False
52 twice_sample = False
53
54 val_list = gen_file_list(frame_save_dir, frame_list_dir)
55 result_file_path = os.path.join(frame_list_dir, result_file_name)
56
57 pretrain = 'imagenet'
58 shift_place = 'blockres'
59 crop_fusion_type = 'avg'
60
61 net = TSN(num_class, test_segment if is_shift else 1, modality,
62 base_model=arch,
63 consensus_type=crop_fusion_type,
64 img_feature_dim=img_feature_dim,
65 pretrain=pretrain,
66 is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
67 non_local=non_local,
68 )
69
70 checkpoint = torch.load(weights_path)
71 checkpoint = checkpoint['state_dict']
72
73 base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
74 replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
75 'base_model.classifier.bias': 'new_fc.bias',
76 }
77 for k, v in replace_dict.items():
78 if k in base_dict:
79 base_dict[v] = base_dict.pop(k)
80
81 net.load_state_dict(base_dict)
82
83 input_size = net.scale_size if full_res else net.input_size
84
85 if test_crop == 1:
86 cropping = torchvision.transforms.Compose([
87 GroupScale(net.scale_size),
88 GroupCenterCrop(input_size),
89 ])
90 elif test_crop == 3: # do not flip, so only 5 crops
91 cropping = torchvision.transforms.Compose([
92 GroupFullResSample(input_size, net.scale_size, flip=False)
93 ])
94 elif test_crop == 5: # do not flip, so only 5 crops
95 cropping = torchvision.transforms.Compose([
96 GroupOverSample(input_size, net.scale_size, flip=False)
97 ])
98 elif test_crop == 10:
99 cropping = torchvision.transforms.Compose([
100 GroupOverSample(input_size, net.scale_size)
101 ])
102 else:
103 raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(test_crop))
104
105 data_loader = torch.utils.data.DataLoader(
106 TSNDataSet(frame_save_dir, val_list, num_segments=test_segment,
107 new_length=1 if modality == "RGB" else 5,
108 modality=modality,
109 image_tmpl=prefix,
110 test_mode=True,
111 remove_missing=False,
112 transform=torchvision.transforms.Compose([
113 cropping,
114 Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
115 ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
116 GroupNormalize(net.input_mean, net.input_std),
117 ]), dense_sample=dense_sample, twice_sample=twice_sample),
118 batch_size=batch_size, shuffle=False,
119 num_workers=workers, pin_memory=True,
120 )
121
122 net = torch.nn.DataParallel(net.cuda())
123 net.eval()
124 data_gen = enumerate(data_loader)
125 max_num = len(data_loader.dataset)
126
127 result_file = open(result_file_path, 'w')
128
129 for i, data_pair in data_gen:
130 directory, data = data_pair
131 with torch.no_grad():
132 if i >= max_num:
133 break
134 num_crop = test_crop
135 if dense_sample:
136 num_crop *= 10 # 10 clips for testing when using dense sample
137
138 if twice_sample:
139 num_crop *= 2
140
141 if modality == 'RGB':
142 length = 3
143 elif modality == 'Flow':
144 length = 10
145 elif modality == 'RGBDiff':
146 length = 18
147 else:
148 raise ValueError("Unknown modality " + modality)
149
150 data_in = data.view(-1, length, data.size(2), data.size(3))
151 if is_shift:
152 data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3))
153 rst, feature = net(data_in)
154 rst = rst.reshape(batch_size, num_crop, -1).mean(1)
155
156 if softmax:
157 # take the softmax to normalize the output to probability
158 rst = F.softmax(rst, dim=1)
159
160 rst = rst.data.cpu().numpy().copy()
161
162 if net.module.is_shift:
163 rst = rst.reshape(batch_size, num_class)
164 else:
165 rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))
166
167 proba = np.squeeze(rst)
168 print(proba)
169 proba = np.exp(proba)/sum(np.exp(proba))
170 result_file.write(str(directory[0]) + ' ')
171 result_file.write(str(proba[0]) + ',' + str(proba[1]) + '\n')
172
173 result_file.close()
174 print('fighting_2 filter end')
1 import os
2 import torch.optim
3 import numpy as np
4 import torch.optim
5 import torch.nn.parallel
6 from ops.models import TSN
7 from ops.transforms import *
8 from ops.dataset import TSNDataSet
9 from torch.nn import functional as F
10
11
12 def gen_file_list(frame_save_dir, frame_list_dir):
13
14 val_path = os.path.join(frame_list_dir, 'flow_val.txt')
15 video_names = os.listdir(frame_save_dir)
16 ucf101_rgb_val_file = open(val_path, 'w')
17
18 for video_name in video_names:
19 images_dir = os.path.join(frame_save_dir, video_name)
20 ucf101_rgb_val_file.write(video_name)
21 ucf101_rgb_val_file.write(' ')
22 ori_list = os.listdir(images_dir)
23 select_list = [element for element in ori_list if 'x' in element]
24 ucf101_rgb_val_file.write(str(len(select_list)))
25 ucf101_rgb_val_file.write('\n')
26
27 ucf101_rgb_val_file.close()
28
29 return val_path
30
31
32 def start_filter(config):
33 arch = config['FLOW']['ARCH']
34 prefix = config['VIDEO']['FLOW_PREFIX']
35 modality = config['FLOW']['MODALITY']
36 test_crop = config['FLOW']['TEST_CROP']
37 batch_size = config['FLOW']['BATCH_SIZE']
38 weights_path = config['MODEL']['CLS_FLOW']
39 test_segment = config['FLOW']['TEST_SEGMENT']
40 frame_save_dir = config['VIDEO']['FLOW_SAVE_DIR']
41 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
42 result_file_name = config['FLOW']['RESULT_FILE']
43
44 workers = 8
45 num_class = 3
46 shift_div = 8
47 img_feature_dim = 256
48
49 softmax = False
50 is_shift = True
51 full_res = False
52 non_local = False
53 dense_sample = False
54 twice_sample = False
55
56 val_list = gen_file_list(frame_save_dir, frame_list_dir)
57 result_file_path = os.path.join(frame_list_dir, result_file_name)
58
59 pretrain = 'imagenet'
60 shift_place = 'blockres'
61 crop_fusion_type = 'avg'
62
63 net = TSN(num_class, test_segment if is_shift else 1, modality,
64 base_model=arch,
65 consensus_type=crop_fusion_type,
66 img_feature_dim=img_feature_dim,
67 pretrain=pretrain,
68 is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
69 non_local=non_local,
70 )
71
72 checkpoint = torch.load(weights_path)
73 checkpoint = checkpoint['state_dict']
74
75 base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
76 replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
77 'base_model.classifier.bias': 'new_fc.bias',
78 }
79 for k, v in replace_dict.items():
80 if k in base_dict:
81 base_dict[v] = base_dict.pop(k)
82
83 net.load_state_dict(base_dict)
84
85 input_size = net.scale_size if full_res else net.input_size
86
87 if test_crop == 1:
88 cropping = torchvision.transforms.Compose([
89 GroupScale(net.scale_size),
90 GroupCenterCrop(input_size),
91 ])
92 elif test_crop == 3: # do not flip, so only 5 crops
93 cropping = torchvision.transforms.Compose([
94 GroupFullResSample(input_size, net.scale_size, flip=False)
95 ])
96 elif test_crop == 5: # do not flip, so only 5 crops
97 cropping = torchvision.transforms.Compose([
98 GroupOverSample(input_size, net.scale_size, flip=False)
99 ])
100 elif test_crop == 10:
101 cropping = torchvision.transforms.Compose([
102 GroupOverSample(input_size, net.scale_size)
103 ])
104 else:
105 raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(test_crop))
106
107 data_loader = torch.utils.data.DataLoader(
108 TSNDataSet(frame_save_dir, val_list, num_segments=test_segment,
109 new_length=1 if modality == "RGB" else 5,
110 modality=modality,
111 image_tmpl=prefix,
112 test_mode=True,
113 remove_missing=False,
114 transform=torchvision.transforms.Compose([
115 cropping,
116 Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
117 ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
118 GroupNormalize(net.input_mean, net.input_std),
119 ]), dense_sample=dense_sample, twice_sample=twice_sample),
120 batch_size=batch_size, shuffle=False,
121 num_workers=workers, pin_memory=True,
122 )
123
124 net = torch.nn.DataParallel(net.cuda())
125 net.eval()
126 data_gen = enumerate(data_loader)
127 max_num = len(data_loader.dataset)
128
129 result_file = open(result_file_path, 'w')
130
131 for i, data_pair in data_gen:
132 directory, data = data_pair
133 with torch.no_grad():
134 if i >= max_num:
135 break
136 num_crop = test_crop
137 if dense_sample:
138 num_crop *= 10 # 10 clips for testing when using dense sample
139
140 if twice_sample:
141 num_crop *= 2
142
143 if modality == 'RGB':
144 length = 3
145 elif modality == 'Flow':
146 length = 10
147 elif modality == 'RGBDiff':
148 length = 18
149 else:
150 raise ValueError("Unknown modality " + modality)
151
152 data_in = data.view(-1, length, data.size(2), data.size(3))
153 if is_shift:
154 data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3))
155 rst, feature = net(data_in)
156 rst = rst.reshape(batch_size, num_crop, -1).mean(1)
157
158 if softmax:
159 # take the softmax to normalize the output to probability
160 rst = F.softmax(rst, dim=1)
161
162 rst = rst.data.cpu().numpy().copy()
163
164 if net.module.is_shift:
165 rst = rst.reshape(batch_size, num_class)
166 else:
167 rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))
168
169 proba = np.squeeze(rst)
170 proba = np.exp(proba)/sum(np.exp(proba))
171 result_file.write(str(directory[0]) + ' ')
172 result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n')
173
174 result_file.close()
175 print('flow filter end')
1 import os
2 import cv2
3 import yaml
4 import tensorflow as tf
5
6
7 def load_config(config_path):
8 with open(config_path, 'r') as cf:
9 config_obj = yaml.load(cf, Loader=yaml.FullLoader)
10 print(config_obj)
11 return config_obj
12
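# Minimal usage sketch (illustrative only; assumes config.yaml sits in the working
# directory, as in main.py):
#   cfg = load_config('config.yaml')
#   print(cfg['MODEL']['CLS_AUDIO'])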
13
14 def load_argue_model(config):
15
16 cls_argue_path = config['MODEL']['CLS_ARGUE']
17 with tf.Graph().as_default():
18
19 if os.path.isfile(cls_argue_path):
20 print('Model filename: %s' % cls_argue_path)
21 with tf.gfile.GFile(cls_argue_path, 'rb') as f:
22 graph_def = tf.GraphDef()
23 graph_def.ParseFromString(f.read())
24 tf.import_graph_def(graph_def, name='')
25
26 x = tf.get_default_graph().get_tensor_by_name("x_batch:0")
27 output = tf.get_default_graph().get_tensor_by_name("output/BiasAdd:0")
28
29 tf_config = tf.ConfigProto()  # separate name so the YAML config argument is not shadowed
30 tf_config.gpu_options.allow_growth = False
31 sess = tf.Session(config=tf_config)
32
33 return x, output, sess
1 from ops.basic_ops import *
1 import torch
2
3
4 class Identity(torch.nn.Module):
5 def forward(self, input):
6 return input
7
8
9 class SegmentConsensus(torch.nn.Module):
10
11 def __init__(self, consensus_type, dim=1):
12 super(SegmentConsensus, self).__init__()
13 self.consensus_type = consensus_type
14 self.dim = dim
15 self.shape = None
16
17 def forward(self, input_tensor):
18 self.shape = input_tensor.size()
19 if self.consensus_type == 'avg':
20 output = input_tensor.mean(dim=self.dim, keepdim=True)
21 elif self.consensus_type == 'identity':
22 output = input_tensor
23 else:
24 output = None
25
26 return output
27
28
29 class ConsensusModule(torch.nn.Module):
30
31 def __init__(self, consensus_type, dim=1):
32 super(ConsensusModule, self).__init__()
33 self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity'
34 self.dim = dim
35
36 def forward(self, input):
37 return SegmentConsensus(self.consensus_type, self.dim)(input)
1 # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 # arXiv:1811.08383
3 # Ji Lin*, Chuang Gan, Song Han
4 # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5
6 import torch.utils.data as data
7
8 from PIL import Image
9 import os
10 import numpy as np
11 from numpy.random import randint
12
13
14 class VideoRecord(object):
15 def __init__(self, row):
16 self._data = row
17
18 @property
19 def path(self):
20 return self._data[0]
21
22 @property
23 def num_frames(self):
24 return int(self._data[1])
25
26
27 class TSNDataSet(data.Dataset):
28 def __init__(self, root_path, list_file,
29 num_segments=3, new_length=1, modality='RGB',
30 image_tmpl='img_{:05d}.jpg', transform=None,
31 random_shift=True, test_mode=False,
32 remove_missing=False, dense_sample=False, twice_sample=False):
33
34 self.root_path = root_path
35 self.list_file = list_file
36 self.num_segments = num_segments
37 self.new_length = new_length
38 self.modality = modality
39 self.image_tmpl = image_tmpl
40 self.transform = transform
41 self.random_shift = random_shift
42 self.test_mode = test_mode
43 self.remove_missing = remove_missing
44 self.dense_sample = dense_sample # using dense sample as I3D
45 self.twice_sample = twice_sample # twice sample for more validation
46 if self.dense_sample:
47 print('=> Using dense sample for the dataset...')
48 if self.twice_sample:
49 print('=> Using twice sample for the dataset...')
50
51 if self.modality == 'RGBDiff':
52 self.new_length += 1 # Diff needs one more image to calculate diff
53
54 self._parse_list()
55
56 def _load_image(self, directory, idx):
57 if self.modality == 'RGB' or self.modality == 'RGBDiff':
58 try:
59 return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
60 except Exception:
61 print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
62 return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
63 elif self.modality == 'Flow':
64 if self.image_tmpl == 'flow_{}_{:05d}.jpg': # ucf
65 x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert(
66 'L')
67 y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert(
68 'L')
69 elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': # something v1 flow
70 x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
71 format(int(directory), 'x', idx))).convert('L')
72 y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
73 format(int(directory), 'y', idx))).convert('L')
74 else:
75 try:
76 # idx_skip = 1 + (idx-1)*5
77 flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert(
78 'RGB')
79 except Exception:
80 print('error loading flow file:',
81 os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
82 flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')
83 # the input flow file is RGB image with (flow_x, flow_y, blank) for each channel
84 flow_x, flow_y, _ = flow.split()
85 x_img = flow_x.convert('L')
86 y_img = flow_y.convert('L')
87
88 return [x_img, y_img]
89
90 def _parse_list(self):
91 # check the frame number is large >3:
92 tmp = [x.strip().split(' ') for x in open(self.list_file)]
93 if not self.test_mode or self.remove_missing:
94 tmp = [item for item in tmp if int(item[1]) >= 3]
95 self.video_list = [VideoRecord(item) for item in tmp]
96
97 if self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
98 for v in self.video_list:
99 v._data[1] = int(v._data[1]) / 2
100 print('video number:%d' % (len(self.video_list)))
101
102 def _sample_indices(self, record):
103 """
104
105 :param record: VideoRecord
106 :return: list
107 """
108 if self.dense_sample: # i3d dense sample
109 sample_pos = max(1, 1 + record.num_frames - 64)
110 t_stride = 64 // self.num_segments
111 start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
112 offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
113 return np.array(offsets) + 1
114 else: # normal sample
115 average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
116 if average_duration > 0:
117 offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration,
118 size=self.num_segments)
119 elif record.num_frames > self.num_segments:
120 offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
121 else:
122 offsets = np.zeros((self.num_segments,))
123 return offsets + 1
124
125 def _get_val_indices(self, record):
126 if self.dense_sample: # i3d dense sample
127 sample_pos = max(1, 1 + record.num_frames - 64)
128 t_stride = 64 // self.num_segments
129 start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
130 offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
131 return np.array(offsets) + 1
132 else:
133 if record.num_frames > self.num_segments + self.new_length - 1:
134 tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
135 offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
136 else:
137 offsets = np.zeros((self.num_segments,))
138 return offsets + 1
139
140 def _get_test_indices(self, record):
141 if self.dense_sample:
142 sample_pos = max(1, 1 + record.num_frames - 64)
143 t_stride = 64 // self.num_segments
144 start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int)
145 offsets = []
146 for start_idx in start_list.tolist():
147 offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
148 return np.array(offsets) + 1
149 elif self.twice_sample:
150 tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
151
152 offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] +
153 [int(tick * x) for x in range(self.num_segments)])
154
155 return offsets + 1
156 else:
157 tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
158 offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
159 return offsets + 1
160
161 def __getitem__(self, index):
162 record = self.video_list[index]
163 # check this is a legit video folder
164
165 if self.image_tmpl == 'flow_{}_{:05d}.jpg':
166 file_name = self.image_tmpl.format('x', 1)
167 full_path = os.path.join(self.root_path, record.path, file_name)
168 elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
169 file_name = self.image_tmpl.format(int(record.path), 'x', 1)
170 full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
171 else:
172 file_name = self.image_tmpl.format(1)
173 full_path = os.path.join(self.root_path, record.path, file_name)
174
175 while not os.path.exists(full_path):
176 print('################## Not Found:', os.path.join(self.root_path, record.path, file_name))
177 index = np.random.randint(len(self.video_list))
178 record = self.video_list[index]
179 if self.image_tmpl == 'flow_{}_{:05d}.jpg':
180 file_name = self.image_tmpl.format('x', 1)
181 full_path = os.path.join(self.root_path, record.path, file_name)
182 elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
183 file_name = self.image_tmpl.format(int(record.path), 'x', 1)
184 full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
185 else:
186 file_name = self.image_tmpl.format(1)
187 full_path = os.path.join(self.root_path, record.path, file_name)
188
189 if not self.test_mode:
190 segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
191 else:
192 segment_indices = self._get_test_indices(record)
193 return self.get(record, segment_indices)
194
195 def get(self, record, indices):
196
197 images = list()
198 for seg_ind in indices:
199 p = int(seg_ind)
200 for i in range(self.new_length):
201 seg_imgs = self._load_image(record.path, p)
202 images.extend(seg_imgs)
203 if p < record.num_frames:
204 p += 1
205
206 process_data = self.transform(images)
207 return record.path, process_data
208
209 def __len__(self):
210 return len(self.video_list)
1 # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 # arXiv:1811.08383
3 # Ji Lin*, Chuang Gan, Song Han
4 # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5
6 import os
7
8 ROOT_DATASET = '/data1/action_1_images/' # '/data/jilin/'
9
10
11 def return_ucf101(modality):
12 filename_categories = 'labels/classInd.txt'
13 if modality == 'RGB':
14 root_data = ROOT_DATASET + 'images'
15 filename_imglist_train = 'file_list/ucf101_rgb_train_split_1.txt'
16 filename_imglist_val = 'file_list/ucf101_rgb_val_split_1.txt'
17 prefix = 'img_{:05d}.jpg'
18 elif modality == 'Flow':
19 root_data = ROOT_DATASET + 'UCF101/jpg'
20 filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt'
21 filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt'
22 prefix = 'flow_{}_{:05d}.jpg'
23 else:
24 raise NotImplementedError('no such modality:' + modality)
25 return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
26
27
28 def return_hmdb51(modality):
29 filename_categories = 51
30 if modality == 'RGB':
31 root_data = ROOT_DATASET + 'HMDB51/images'
32 filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt'
33 filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt'
34 prefix = 'img_{:05d}.jpg'
35 elif modality == 'Flow':
36 root_data = ROOT_DATASET + 'HMDB51/images'
37 filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt'
38 filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt'
39 prefix = 'flow_{}_{:05d}.jpg'
40 else:
41 raise NotImplementedError('no such modality:' + modality)
42 return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
43
44
45 def return_something(modality):
46 filename_categories = 'something/v1/category.txt'
47 if modality == 'RGB':
48 root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1'
49 filename_imglist_train = 'something/v1/train_videofolder.txt'
50 filename_imglist_val = 'something/v1/val_videofolder.txt'
51 prefix = '{:05d}.jpg'
52 elif modality == 'Flow':
53 root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1-flow'
54 filename_imglist_train = 'something/v1/train_videofolder_flow.txt'
55 filename_imglist_val = 'something/v1/val_videofolder_flow.txt'
56 prefix = '{:06d}-{}_{:05d}.jpg'
57 else:
58 print('no such modality:'+modality)
59 raise NotImplementedError
60 return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
61
62
63 def return_somethingv2(modality):
64 filename_categories = 'something/v2/category.txt'
65 if modality == 'RGB':
66 root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-frames'
67 filename_imglist_train = 'something/v2/train_videofolder.txt'
68 filename_imglist_val = 'something/v2/val_videofolder.txt'
69 prefix = '{:06d}.jpg'
70 elif modality == 'Flow':
71 root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow'
72 filename_imglist_train = 'something/v2/train_videofolder_flow.txt'
73 filename_imglist_val = 'something/v2/val_videofolder_flow.txt'
74 prefix = '{:06d}.jpg'
75 else:
76 raise NotImplementedError('no such modality:'+modality)
77 return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
78
79
80 def return_jester(modality):
81 filename_categories = 'jester/category.txt'
82 if modality == 'RGB':
83 prefix = '{:05d}.jpg'
84 root_data = ROOT_DATASET + 'jester/20bn-jester-v1'
85 filename_imglist_train = 'jester/train_videofolder.txt'
86 filename_imglist_val = 'jester/val_videofolder.txt'
87 else:
88 raise NotImplementedError('no such modality:'+modality)
89 return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
90
91
92 def return_kinetics(modality):
93 filename_categories = 400
94 if modality == 'RGB':
95 root_data = ROOT_DATASET + 'kinetics/images'
96 filename_imglist_train = 'kinetics/labels/train_videofolder.txt'
97 filename_imglist_val = 'kinetics/labels/val_videofolder.txt'
98 prefix = 'img_{:05d}.jpg'
99 else:
100 raise NotImplementedError('no such modality:' + modality)
101 return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
102
103
104 def return_dataset(dataset, modality):
105 dict_single = {'jester': return_jester, 'something': return_something, 'somethingv2': return_somethingv2,
106 'ucf101': return_ucf101, 'hmdb51': return_hmdb51,
107 'kinetics': return_kinetics}
108 if dataset in dict_single:
109 file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality)
110 else:
111 raise ValueError('Unknown dataset '+dataset)
112
113 file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train)
114 file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val)
115 if isinstance(file_categories, str):
116 file_categories = os.path.join(ROOT_DATASET, file_categories)
117 with open(file_categories) as f:
118 lines = f.readlines()
119 categories = [item.rstrip() for item in lines]
120 else: # number of categories
121 categories = [None] * file_categories
122 n_class = len(categories)
123 print('{}: {} classes'.format(dataset, n_class))
124 return n_class, file_imglist_train, file_imglist_val, root_data, prefix
1 # Non-local block using embedded gaussian
2 # Code from
3 # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py
4 import torch
5 from torch import nn
6 from torch.nn import functional as F
7
8
9 class _NonLocalBlockND(nn.Module):
10 def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
11 super(_NonLocalBlockND, self).__init__()
12
13 assert dimension in [1, 2, 3]
14
15 self.dimension = dimension
16 self.sub_sample = sub_sample
17
18 self.in_channels = in_channels
19 self.inter_channels = inter_channels
20
21 if self.inter_channels is None:
22 self.inter_channels = in_channels // 2
23 if self.inter_channels == 0:
24 self.inter_channels = 1
25
26 if dimension == 3:
27 conv_nd = nn.Conv3d
28 max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
29 bn = nn.BatchNorm3d
30 elif dimension == 2:
31 conv_nd = nn.Conv2d
32 max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
33 bn = nn.BatchNorm2d
34 else:
35 conv_nd = nn.Conv1d
36 max_pool_layer = nn.MaxPool1d(kernel_size=(2))
37 bn = nn.BatchNorm1d
38
39 self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
40 kernel_size=1, stride=1, padding=0)
41
42 if bn_layer:
43 self.W = nn.Sequential(
44 conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
45 kernel_size=1, stride=1, padding=0),
46 bn(self.in_channels)
47 )
48 nn.init.constant_(self.W[1].weight, 0)
49 nn.init.constant_(self.W[1].bias, 0)
50 else:
51 self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
52 kernel_size=1, stride=1, padding=0)
53 nn.init.constant_(self.W.weight, 0)
54 nn.init.constant_(self.W.bias, 0)
55
56 self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
57 kernel_size=1, stride=1, padding=0)
58 self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
59 kernel_size=1, stride=1, padding=0)
60
61 if sub_sample:
62 self.g = nn.Sequential(self.g, max_pool_layer)
63 self.phi = nn.Sequential(self.phi, max_pool_layer)
64
65 def forward(self, x):
66 '''
67 :param x: (b, c, t, h, w)
68 :return:
69 '''
70
71 batch_size = x.size(0)
72
73 g_x = self.g(x).view(batch_size, self.inter_channels, -1)
74 g_x = g_x.permute(0, 2, 1)
75
76 theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
77 theta_x = theta_x.permute(0, 2, 1)
78 phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
79 f = torch.matmul(theta_x, phi_x)
80 f_div_C = F.softmax(f, dim=-1)
81
82 y = torch.matmul(f_div_C, g_x)
83 y = y.permute(0, 2, 1).contiguous()
84 y = y.view(batch_size, self.inter_channels, *x.size()[2:])
85 W_y = self.W(y)
86 z = W_y + x
87
88 return z
89
90
91 class NONLocalBlock1D(_NonLocalBlockND):
92 def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
93 super(NONLocalBlock1D, self).__init__(in_channels,
94 inter_channels=inter_channels,
95 dimension=1, sub_sample=sub_sample,
96 bn_layer=bn_layer)
97
98
99 class NONLocalBlock2D(_NonLocalBlockND):
100 def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
101 super(NONLocalBlock2D, self).__init__(in_channels,
102 inter_channels=inter_channels,
103 dimension=2, sub_sample=sub_sample,
104 bn_layer=bn_layer)
105
106
107 class NONLocalBlock3D(_NonLocalBlockND):
108 def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
109 super(NONLocalBlock3D, self).__init__(in_channels,
110 inter_channels=inter_channels,
111 dimension=3, sub_sample=sub_sample,
112 bn_layer=bn_layer)
113
114
115 class NL3DWrapper(nn.Module):
116 def __init__(self, block, n_segment):
117 super(NL3DWrapper, self).__init__()
118 self.block = block
119 self.nl = NONLocalBlock3D(block.bn3.num_features)
120 self.n_segment = n_segment
121
122 def forward(self, x):
123 x = self.block(x)
124
125 nt, c, h, w = x.size()
126 x = x.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w
127 x = self.nl(x)
128 x = x.transpose(1, 2).contiguous().view(nt, c, h, w)
129 return x
130
131
132 def make_non_local(net, n_segment):
133 import torchvision
134 import archs
135 if isinstance(net, torchvision.models.ResNet):
136 net.layer2 = nn.Sequential(
137 NL3DWrapper(net.layer2[0], n_segment),
138 net.layer2[1],
139 NL3DWrapper(net.layer2[2], n_segment),
140 net.layer2[3],
141 )
142 net.layer3 = nn.Sequential(
143 NL3DWrapper(net.layer3[0], n_segment),
144 net.layer3[1],
145 NL3DWrapper(net.layer3[2], n_segment),
146 net.layer3[3],
147 NL3DWrapper(net.layer3[4], n_segment),
148 net.layer3[5],
149 )
150 else:
151 raise NotImplementedError
152
153
154 if __name__ == '__main__':
155 from torch.autograd import Variable
156 import torch
157
158 sub_sample = True
159 bn_layer = True
160
161 img = Variable(torch.zeros(2, 3, 20))
162 net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
163 out = net(img)
164 print(out.size())
165
166 img = Variable(torch.zeros(2, 3, 20, 20))
167 net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
168 out = net(img)
169 print(out.size())
170
171 img = Variable(torch.randn(2, 3, 10, 20, 20))
172 net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
173 out = net(img)
174 print(out.size())
1 # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 # arXiv:1811.08383
3 # Ji Lin*, Chuang Gan, Song Han
4 # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5
6 import torch
7 import torch.nn as nn
8 import torch.nn.functional as F
9
10
11 class TemporalShift(nn.Module):
12 def __init__(self, net, n_segment=3, n_div=8, inplace=False):
13 super(TemporalShift, self).__init__()
14 self.net = net
15 self.n_segment = n_segment
16 self.fold_div = n_div
17 self.inplace = inplace
18 if inplace:
19 print('=> Using in-place shift...')
20 print('=> Using fold div: {}'.format(self.fold_div))
21
22 def forward(self, x):
23 x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
24 return self.net(x)
25
26 @staticmethod
27 def shift(x, n_segment, fold_div=3, inplace=False):
28 nt, c, h, w = x.size()
29 n_batch = nt // n_segment
30 x = x.view(n_batch, n_segment, c, h, w)
31
32 fold = c // fold_div
33 if inplace:
34 # Due to some out of order error when performing parallel computing.
35 # May need to write a CUDA kernel.
36 raise NotImplementedError
37 # out = InplaceShift.apply(x, fold)
38 else:
39 out = torch.zeros_like(x)
40 out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
41 out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
42 out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
43
44 return out.view(nt, c, h, w)
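# Shape sketch (descriptive only): with n_segment=8 and two clips in a batch, x enters as
# (16, C, H, W) and is viewed as (2, 8, C, H, W); the first C//fold_div channels take their
# values from the next segment ("shift left"), the next C//fold_div from the previous segment
# ("shift right"), the remaining channels are copied unchanged, and the tensor is reshaped
# back to (16, C, H, W).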
45
46
47 class InplaceShift(torch.autograd.Function):
48 # Special thanks to @raoyongming for the help to this function
49 @staticmethod
50 def forward(ctx, input, fold):
51 # not support higher order gradient
52 # input = input.detach_()
53 ctx.fold_ = fold
54 n, t, c, h, w = input.size()
55 buffer = input.data.new(n, t, fold, h, w).zero_()
56 buffer[:, :-1] = input.data[:, 1:, :fold]
57 input.data[:, :, :fold] = buffer
58 buffer.zero_()
59 buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold]
60 input.data[:, :, fold: 2 * fold] = buffer
61 return input
62
63 @staticmethod
64 def backward(ctx, grad_output):
65 # grad_output = grad_output.detach_()
66 fold = ctx.fold_
67 n, t, c, h, w = grad_output.size()
68 buffer = grad_output.data.new(n, t, fold, h, w).zero_()
69 buffer[:, 1:] = grad_output.data[:, :-1, :fold]
70 grad_output.data[:, :, :fold] = buffer
71 buffer.zero_()
72 buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold]
73 grad_output.data[:, :, fold: 2 * fold] = buffer
74 return grad_output, None
75
76
77 class TemporalPool(nn.Module):
78 def __init__(self, net, n_segment):
79 super(TemporalPool, self).__init__()
80 self.net = net
81 self.n_segment = n_segment
82
83 def forward(self, x):
84 x = self.temporal_pool(x, n_segment=self.n_segment)
85 return self.net(x)
86
87 @staticmethod
88 def temporal_pool(x, n_segment):
89 nt, c, h, w = x.size()
90 n_batch = nt // n_segment
91 x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w
92 x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
93 x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w)
94 return x
95
96
97 def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
98 if temporal_pool:
99 n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
100 else:
101 n_segment_list = [n_segment] * 4
102 assert n_segment_list[-1] > 0
103 print('=> n_segment per stage: {}'.format(n_segment_list))
104
105 import torchvision
106 if isinstance(net, torchvision.models.ResNet):
107 if place == 'block':
108 def make_block_temporal(stage, this_segment):
109 blocks = list(stage.children())
110 print('=> Processing stage with {} blocks'.format(len(blocks)))
111 for i, b in enumerate(blocks):
112 blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
113 return nn.Sequential(*(blocks))
114
115 net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
116 net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
117 net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
118 net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
119
120 elif 'blockres' in place:
121 n_round = 1
122 if len(list(net.layer3.children())) >= 23:
123 n_round = 2
124 print('=> Using n_round {} to insert temporal shift'.format(n_round))
125
126 def make_block_temporal(stage, this_segment):
127 blocks = list(stage.children())
128 print('=> Processing stage with {} blocks residual'.format(len(blocks)))
129 for i, b in enumerate(blocks):
130 if i % n_round == 0:
131 blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div)
132 return nn.Sequential(*blocks)
133
134 net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
135 net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
136 net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
137 net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
138 else:
139 raise NotImplementedError(place)
140
141
142 def make_temporal_pool(net, n_segment):
143 import torchvision
144 if isinstance(net, torchvision.models.ResNet):
145 print('=> Injecting nonlocal pooling')
146 net.layer2 = TemporalPool(net.layer2, n_segment)
147 else:
148 raise NotImplementedError
149
150
151 if __name__ == '__main__':
152 # test inplace shift v.s. vanilla shift
153 tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False)
154 tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True)
155
156 print('=> Testing CPU...')
157 # test forward
158 with torch.no_grad():
159 for i in range(10):
160 x = torch.rand(2 * 8, 3, 224, 224)
161 y1 = tsm1(x)
162 y2 = tsm2(x)
163 assert torch.norm(y1 - y2).item() < 1e-5
164
165 # test backward
166 with torch.enable_grad():
167 for i in range(10):
168 x1 = torch.rand(2 * 8, 3, 224, 224)
169 x1.requires_grad_()
170 x2 = x1.clone()
171 y1 = tsm1(x1)
172 y2 = tsm2(x2)
173 grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
174 grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
175 assert torch.norm(grad1 - grad2).item() < 1e-5
176
177 print('=> Testing GPU...')
178 tsm1.cuda()
179 tsm2.cuda()
180 # test forward
181 with torch.no_grad():
182 for i in range(10):
183 x = torch.rand(2 * 8, 3, 224, 224).cuda()
184 y1 = tsm1(x)
185 y2 = tsm2(x)
186 assert torch.norm(y1 - y2).item() < 1e-5
187
188 # test backward
189 with torch.enable_grad():
190 for i in range(10):
191 x1 = torch.rand(2 * 8, 3, 224, 224).cuda()
192 x1.requires_grad_()
193 x2 = x1.clone()
194 y1 = tsm1(x1)
195 y2 = tsm2(x2)
196 grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
197 grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
198 assert torch.norm(grad1 - grad2).item() < 1e-5
199 print('Test passed.')
200
201
202
203
1 import numpy as np
2
3
4 def softmax(scores):
5 es = np.exp(scores - scores.max(axis=-1)[..., None])
6 return es / es.sum(axis=-1)[..., None]
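# Illustrative check (not part of the pipeline): each row of the output sums to 1, e.g.
#   softmax(np.array([[2.0, 1.0, 0.1]])).sum(axis=-1)  ->  array([1.])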
7
8
9 class AverageMeter(object):
10 """Computes and stores the average and current value"""
11
12 def __init__(self):
13 self.reset()
14
15 def reset(self):
16 self.val = 0
17 self.avg = 0
18 self.sum = 0
19 self.count = 0
20
21 def update(self, val, n=1):
22 self.val = val
23 self.sum += val * n
24 self.count += n
25 self.avg = self.sum / self.count
26
27
28 def accuracy(output, target, topk=(1,)):
29 """Computes the precision@k for the specified values of k"""
30 maxk = max(topk)
31 batch_size = target.size(0)
32
33 _, pred = output.topk(maxk, 1, True, True)
34 pred = pred.t()
35 correct = pred.eq(target.view(1, -1).expand_as(pred))
36
37 res = []
38 for k in topk:
39 correct_k = correct[:k].view(-1).float().sum(0)
40 res.append(correct_k.mul_(100.0 / batch_size))
41 return res
1 import os
2 import cv2
3 import numpy as np
4 import pickle
5
6 def start_filter(config):
7 cls_class_path = config['MODEL']['CLS_PERSON']
8 feature_save_dir = config['VIDEO']['FACE_FEATURE_DIR']
9 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
10 result_file_name = config['PERSON']['RESULT_FILE']
11 feature_name = config['PERSON']['DATA_NAME']
12
13 xgboost_model = pickle.load(open(cls_class_path, "rb"))
14
15 result_file_path = os.path.join(frame_list_dir, result_file_name)
16 result_file = open(result_file_path, 'w')
17
18 feature_path = os.path.join(feature_save_dir, feature_name)
19 val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1')
20
21 X_val = []
22 Y_val = []
23 Y_names = []
24 for j in range(len(val_annotation_pairs)):
25 pair = val_annotation_pairs[j]
26 X_val.append(np.squeeze(pair[0]))
27 Y_val.append(pair[1])
28 Y_names.append(pair[2])
29
30 X_val = np.array(X_val)
31 y_pred = xgboost_model.predict_proba(X_val)
32
33 for i, Y_name in enumerate(Y_names):
34 result_file.write(Y_name + ' ')
35 result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n')
36
37 result_file.close()
38
39
40
41
42
1 import os
2 import torch.optim
3 import numpy as np
4 import torch.optim
5 import torch.nn.parallel
6 from ops.models import TSN
7 from ops.transforms import *
8 from ops.dataset import TSNDataSet
9 from torch.nn import functional as F
10
11
12 def gen_file_list(frame_save_dir, frame_list_dir):
13
14 val_path = os.path.join(frame_list_dir, 'val.txt')
15 video_names = os.listdir(frame_save_dir)
16 ucf101_rgb_val_file = open(val_path, 'w')
17
18 for video_name in video_names:
19 images_dir = os.path.join(frame_save_dir, video_name)
20 ucf101_rgb_val_file.write(video_name)
21 ucf101_rgb_val_file.write(' ')
22 ucf101_rgb_val_file.write(str(len(os.listdir(images_dir))))
23 ucf101_rgb_val_file.write('\n')
24
25 ucf101_rgb_val_file.close()
26
27 return val_path
28
29
30 def start_filter(config):
31 arch = config['POSE']['ARCH']
32 prefix = config['VIDEO']['PREFIX']
33 modality = config['POSE']['MODALITY']
34 test_crop = config['POSE']['TEST_CROP']
35 batch_size = config['POSE']['BATCH_SIZE']
36 weights_path = config['MODEL']['CLS_POSE']
37 test_segment = config['POSE']['TEST_SEGMENT']
38 frame_save_dir = config['VIDEO']['POSE_FRAME_SAVE_DIR']
39 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
40 result_file_name = config['POSE']['RESULT_FILE']
41
42 workers = 8
43 num_class = 3
44 shift_div = 8
45 img_feature_dim = 256
46
47 softmax = False
48 is_shift = True
49 full_res = False
50 non_local = False
51 dense_sample = False
52 twice_sample = False
53
54 val_list = gen_file_list(frame_save_dir, frame_list_dir)
55 result_file_path = os.path.join(frame_list_dir, result_file_name)
56
57 pretrain = 'imagenet'
58 shift_place = 'blockres'
59 crop_fusion_type = 'avg'
60
61 net = TSN(num_class, test_segment if is_shift else 1, modality,
62 base_model=arch,
63 consensus_type=crop_fusion_type,
64 img_feature_dim=img_feature_dim,
65 pretrain=pretrain,
66 is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
67 non_local=non_local,
68 )
69
70 checkpoint = torch.load(weights_path)
71 checkpoint = checkpoint['state_dict']
72
73 base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
74 replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
75 'base_model.classifier.bias': 'new_fc.bias',
76 }
77 for k, v in replace_dict.items():
78 if k in base_dict:
79 base_dict[v] = base_dict.pop(k)
80
81 net.load_state_dict(base_dict)
82
83 input_size = net.scale_size if full_res else net.input_size
84
85 if test_crop == 1:
86 cropping = torchvision.transforms.Compose([
87 GroupScale(net.scale_size),
88 GroupCenterCrop(input_size),
89 ])
90 elif test_crop == 3: # do not flip, so only 5 crops
91 cropping = torchvision.transforms.Compose([
92 GroupFullResSample(input_size, net.scale_size, flip=False)
93 ])
94 elif test_crop == 5: # do not flip, so only 5 crops
95 cropping = torchvision.transforms.Compose([
96 GroupOverSample(input_size, net.scale_size, flip=False)
97 ])
98 elif test_crop == 10:
99 cropping = torchvision.transforms.Compose([
100 GroupOverSample(input_size, net.scale_size)
101 ])
102 else:
103 raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(test_crop))
104
105 data_loader = torch.utils.data.DataLoader(
106 TSNDataSet(frame_save_dir, val_list, num_segments=test_segment,
107 new_length=1 if modality == "RGB" else 5,
108 modality=modality,
109 image_tmpl=prefix,
110 test_mode=True,
111 remove_missing=False,
112 transform=torchvision.transforms.Compose([
113 cropping,
114 Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
115 ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
116 GroupNormalize(net.input_mean, net.input_std),
117 ]), dense_sample=dense_sample, twice_sample=twice_sample),
118 batch_size=batch_size, shuffle=False,
119 num_workers=workers, pin_memory=True,
120 )
121
122 net = torch.nn.DataParallel(net.cuda())
123 net.eval()
124 data_gen = enumerate(data_loader)
125 max_num = len(data_loader.dataset)
126
127 result_file = open(result_file_path, 'w')
128
129 for i, data_pair in data_gen:
130 directory, data = data_pair
131 with torch.no_grad():
132 if i >= max_num:
133 break
134 num_crop = test_crop
135 if dense_sample:
136 num_crop *= 10 # 10 clips for testing when using dense sample
137
138 if twice_sample:
139 num_crop *= 2
140
141 if modality == 'RGB':
142 length = 3
143 elif modality == 'Flow':
144 length = 10
145 elif modality == 'RGBDiff':
146 length = 18
147 else:
148 raise ValueError("Unknown modality " + modality)
149
150 data_in = data.view(-1, length, data.size(2), data.size(3))
151 if is_shift:
152 data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3))
153 rst, feature = net(data_in)
154 rst = rst.reshape(batch_size, num_crop, -1).mean(1)
155
156 if softmax:
157 # take the softmax to normalize the output to probability
158 rst = F.softmax(rst, dim=1)
159
160 rst = rst.data.cpu().numpy().copy()
161
162 if net.module.is_shift:
163 rst = rst.reshape(batch_size, num_class)
164 else:
165 rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))
166
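            # softmax is False above, so rst holds raw class scores; normalize them here
            # into a probability vector over the 3 classes before writing the result line.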
167 proba = np.squeeze(rst)
168 proba = np.exp(proba)/sum(np.exp(proba))
169 result_file.write(str(directory[0]) + ' ')
170 result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n')
171
172 result_file.close()
173 print('video filter end')
\ No newline at end of file
1 import os
2 import cv2
3 import load_util
4 import media_util
5 import numpy as np
6 from sklearn.metrics import confusion_matrix
7 import fighting_filter, emotion_filter, argue_filter, audio_filter, class_filter
8 import video_filter, pose_filter, flow_filter
9
10
11
12 def accuracy_cal(config):
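    # Expected line formats (inferred from the parsing below):
    #   result file: '<video_id> p1,p2,p3'
    #   label file:  one header row, then '<video_id> <label>' with labels in 1..3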
13
14 label_file_path = config['VIDEO']['LABEL_PATH']
15 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
16 final_file_name = config['AUDIO']['RESULT_FILE']
17
18 final_file_path = os.path.join(frame_list_dir, final_file_name)
19 final_file_lines = open(final_file_path).readlines()
20 label_file_lines = open(label_file_path).readlines()
21
22
23 final_pairs = {line.strip().split(' ')[0]: line.strip().split(' ')[1] for line in final_file_lines}
24
25 lines_num = len(label_file_lines) - 1
26 hit = 0
27 for i, label_line in enumerate(label_file_lines):
28 if i == 0:
29 continue
30 file, label = label_line.strip().split(' ')
31 final_pre = final_pairs[file]
32         final_pre_class = np.argmax(np.array(final_pre.split(','), dtype=np.float32)) + 1  # cast to float so argmax compares numerically, not lexicographically
33 print(final_pre_class, label)
34 if final_pre_class == int(label):
35 hit += 1
36
37 return hit/lines_num
38
39
40 def main():
41 config_path = r'config.yaml'
42 config = load_util.load_config(config_path)
43
44 media_util.extract_wav(config)
45 media_util.extract_frame(config)
46 media_util.extract_frame_pose(config)
47 media_util.extract_is10(config)
48 media_util.extract_random_face_feature(config)
49 media_util.extract_mirror(config)
50
51     fighting_filter.start_filter(config)
52 emotion_filter.start_filter(config)
53
54 audio_filter.start_filter(config)
55
56 class_filter.start_filter(config)
57 video_filter.start_filter(config)
58 pose_filter.start_filter(config)
59
60 flow_filter.start_filter(config)
61
62 acc = accuracy_cal(config)
63 print(acc)
64
65
66 if __name__ == '__main__':
67 main()
\ No newline at end of file
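Each filter above writes a result file whose lines look like `<video_id> p1,p2,p3`, and accuracy_cal scores one such file against the label list. Below is a minimal, hypothetical sketch of how several per-filter result files could be averaged into a single fused file of the same format before scoring; the helper name, the equal weighting, and the `fused_result.txt` output name are illustrative assumptions, not the repository's actual fusion step.

import os
import numpy as np

def fuse_results(frame_list_dir, result_names, out_name='fused_result.txt'):
    # result_names: per-filter result file names (hypothetical, e.g. the
    # RESULT_FILE values configured for the audio/video/pose filters)
    scores = {}
    for name in result_names:
        with open(os.path.join(frame_list_dir, name)) as f:
            for line in f:
                if not line.strip():
                    continue
                vid, probs = line.strip().split(' ')
                p = np.array(probs.split(','), dtype=np.float32)
                scores.setdefault(vid, []).append(p)
    with open(os.path.join(frame_list_dir, out_name), 'w') as out:
        for vid, plist in scores.items():
            p = np.mean(plist, axis=0)  # equal-weight average across filters
            out.write(vid + ' ' + ','.join(str(x) for x in p) + '\n')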
1 import os
2 import torch.optim
3 import numpy as np
4 import torch.nn.parallel
5 from ops.models import TSN
6 from ops.transforms import *
7 from ops.dataset import TSNDataSet
8 from torch.nn import functional as F
9
10
11 def gen_file_list(frame_save_dir, frame_list_dir):
12
13 val_path = os.path.join(frame_list_dir, 'val.txt')
14 video_names = os.listdir(frame_save_dir)
15 ucf101_rgb_val_file = open(val_path, 'w')
16
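    # one line per extracted-frame directory: '<video_dir_name> <num_frames>',
    # which is the list format handed to TSNDataSet as val_list in start_filter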
17 for video_name in video_names:
18 images_dir = os.path.join(frame_save_dir, video_name)
19 ucf101_rgb_val_file.write(video_name)
20 ucf101_rgb_val_file.write(' ')
21 ucf101_rgb_val_file.write(str(len(os.listdir(images_dir))))
22 ucf101_rgb_val_file.write('\n')
23
24 ucf101_rgb_val_file.close()
25
26 return val_path
27
28
29 def start_filter(config):
30 arch = config['FIGHTING']['ARCH']
31 prefix = config['VIDEO']['PREFIX']
32 modality = config['VIDEO_FILTER']['MODALITY']
33 test_crop = config['VIDEO_FILTER']['TEST_CROP']
34 batch_size = config['VIDEO_FILTER']['BATCH_SIZE']
35 weights_path = config['MODEL']['CLS_VIDEO']
36 test_segment = config['VIDEO_FILTER']['TEST_SEGMENT']
37 frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR']
38 frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
39 result_file_name = config['VIDEO_FILTER']['RESULT_FILE']
40
41 workers = 8
42 num_class = 3
43 shift_div = 8
44 img_feature_dim = 256
45
46 softmax = False
47 is_shift = True
48 full_res = False
49 non_local = False
50 dense_sample = False
51 twice_sample = False
52
53 val_list = gen_file_list(frame_save_dir, frame_list_dir)
54 result_file_path = os.path.join(frame_list_dir, result_file_name)
55
56 pretrain = 'imagenet'
57 shift_place = 'blockres'
58 crop_fusion_type = 'avg'
59
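    # TSN backbone with the temporal shift module (TSM) enabled: shift_div and
    # shift_place control how feature channels are shifted across the test_segment frames.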
60 net = TSN(num_class, test_segment if is_shift else 1, modality,
61 base_model=arch,
62 consensus_type=crop_fusion_type,
63 img_feature_dim=img_feature_dim,
64 pretrain=pretrain,
65 is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
66 non_local=non_local,
67 )
68
69 checkpoint = torch.load(weights_path)
70 checkpoint = checkpoint['state_dict']
71
72 base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
73 replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
74 'base_model.classifier.bias': 'new_fc.bias',
75 }
76 for k, v in replace_dict.items():
77 if k in base_dict:
78 base_dict[v] = base_dict.pop(k)
79
80 net.load_state_dict(base_dict)
81
82 input_size = net.scale_size if full_res else net.input_size
83
84 if test_crop == 1:
85 cropping = torchvision.transforms.Compose([
86 GroupScale(net.scale_size),
87 GroupCenterCrop(input_size),
88 ])
89     elif test_crop == 3:  # do not flip, 3 full-resolution crops
90 cropping = torchvision.transforms.Compose([
91 GroupFullResSample(input_size, net.scale_size, flip=False)
92 ])
93 elif test_crop == 5: # do not flip, so only 5 crops
94 cropping = torchvision.transforms.Compose([
95 GroupOverSample(input_size, net.scale_size, flip=False)
96 ])
97 elif test_crop == 10:
98 cropping = torchvision.transforms.Compose([
99 GroupOverSample(input_size, net.scale_size)
100 ])
101 else:
102         raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop))
103
104 data_loader = torch.utils.data.DataLoader(
105 TSNDataSet(frame_save_dir, val_list, num_segments=test_segment,
106 new_length=1 if modality == "RGB" else 5,
107 modality=modality,
108 image_tmpl=prefix,
109 test_mode=True,
110 remove_missing=False,
111 transform=torchvision.transforms.Compose([
112 cropping,
113 Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
114 ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
115 GroupNormalize(net.input_mean, net.input_std),
116 ]), dense_sample=dense_sample, twice_sample=twice_sample),
117 batch_size=batch_size, shuffle=False,
118 num_workers=workers, pin_memory=True,
119 )
120
121 net = torch.nn.DataParallel(net.cuda())
122 net.eval()
123 data_gen = enumerate(data_loader)
124 max_num = len(data_loader.dataset)
125
126 result_file = open(result_file_path, 'w')
127
128 for i, data_pair in data_gen:
129 directory, data = data_pair
130 with torch.no_grad():
131 if i >= max_num:
132 break
133 num_crop = test_crop
134 if dense_sample:
135 num_crop *= 10 # 10 clips for testing when using dense sample
136
137 if twice_sample:
138 num_crop *= 2
139
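            # per-segment input channel count by modality: 3 for RGB, 10 for Flow
            # (5 stacked flow frames x 2 channels), 18 for RGBDiff (stacked RGB differences)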
140 if modality == 'RGB':
141 length = 3
142 elif modality == 'Flow':
143 length = 10
144 elif modality == 'RGBDiff':
145 length = 18
146 else:
147 raise ValueError("Unknown modality " + modality)
148
149 data_in = data.view(-1, length, data.size(2), data.size(3))
150 if is_shift:
151 data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3))
152
153 rst, feature = net(data_in)
154 rst = rst.reshape(batch_size, num_crop, -1).mean(1)
155
156 if softmax:
157 # take the softmax to normalize the output to probability
158 rst = F.softmax(rst, dim=1)
159
160 rst = rst.data.cpu().numpy().copy()
161
162 if net.module.is_shift:
163 rst = rst.reshape(batch_size, num_class)
164 else:
165 rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))
166
167 proba = np.squeeze(rst)
168 proba = np.exp(proba)/sum(np.exp(proba))
169 result_file.write(str(directory[0]) + ' ')
170 result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n')
171
172 result_file.close()
173 print('video filter end')
\ No newline at end of file