add draw
Showing 8 changed files with 169 additions and 32 deletions
... | @@ -3,9 +3,9 @@ seed: 3407 | ... | @@ -3,9 +3,9 @@ seed: 3407 |
3 | dataset: | 3 | dataset: |
4 | name: 'SLData' | 4 | name: 'SLData' |
5 | args: | 5 | args: |
6 | data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2' | 6 | data_root: '/dataset160x14' |
7 | train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv' | 7 | train_anno_file: '/dataset160x14/train.csv' |
8 | val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv' | 8 | val_anno_file: '/dataset160x14/valid.csv' |
9 | 9 | ||
10 | dataloader: | 10 | dataloader: |
11 | batch_size: 8 | 11 | batch_size: 8 |
... | @@ -18,7 +18,7 @@ model: | ... | @@ -18,7 +18,7 @@ model: |
18 | args: | 18 | args: |
19 | seq_lens: 160 | 19 | seq_lens: 160 |
20 | num_classes: 10 | 20 | num_classes: 10 |
21 | embed_dim: 9 | 21 | embed_dim: 14 |
22 | depth: 6 | 22 | depth: 6 |
23 | num_heads: 1 | 23 | num_heads: 1 |
24 | mlp_ratio: 4.0 | 24 | mlp_ratio: 4.0 |
... | @@ -36,6 +36,11 @@ solver: | ... | @@ -36,6 +36,11 @@ solver: |
36 | epoch: 100 | 36 | epoch: 100 |
37 | base_on: null | 37 | base_on: null |
38 | model_path: null | 38 | model_path: null |
39 | val_image_path: '/labeled/valid/image' | ||
40 | val_go_path: '/go_res/valid' | ||
41 | val_map_path: '/dataset160x14/create_map.json' | ||
42 | draw_font_path: '/dataset160x14/STZHONGS.TTF' | ||
43 | thresholds: 0.5 | ||
39 | 44 | ||
40 | optimizer: | 45 | optimizer: |
41 | name: 'Adam' | 46 | name: 'Adam' |
... | @@ -58,5 +63,5 @@ solver: | ... | @@ -58,5 +63,5 @@ solver: |
58 | alpha: 0.8 | 63 | alpha: 0.8 |
59 | 64 | ||
60 | logger: | 65 | logger: |
61 | log_root: '/Users/zhouweiqi/Downloads/test/logs' | 66 | log_root: '/logs' |
62 | suffix: 'sl-6-1' | 67 | suffix: 'sl-6-1' |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
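Note: the jump from embed_dim: 9 to embed_dim: 14 matches the reworked per-token feature vector in the dataset builder below, where dim = 1 + 5 + 8 (a padding indicator, the simple_word2vec statistics, and eight normalized box coordinates) and the jwq_word2vec text vector (text_vec_max_lens = 15 * 50) is commented out. A minimal sketch of how the 14 dimensions line up, assuming simple_word2vec returns five scalar features; the toy helper and the sample values are placeholders, not the real implementation:

```python
# Sketch only: toy_simple_word2vec stands in for the real simple_word2vec,
# which is assumed to return five scalar text statistics.
def toy_simple_word2vec(text):
    return [len(text) / 50.0, 0.0, 0.0, 0.0, 0.0]

w, h = 1000, 600                                      # image size (made up)
x0, y0, x1, y1, x2, y2, x3, y3 = 10, 20, 110, 20, 110, 50, 10, 50

feature_vec = [0.]                                    # 1: padding/real indicator (1. or 0. in the builder)
feature_vec.extend(toy_simple_word2vec('机打号码'))    # 5: text statistics
feature_vec.extend([x0/w, y0/h, x1/w, y1/h,
                    x2/w, y2/h, x3/w, y3/h])          # 8: normalized box corners
assert len(feature_vec) == 14                         # matches embed_dim: 14
```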
... | @@ -7,7 +7,7 @@ import uuid | ... | @@ -7,7 +7,7 @@ import uuid |
7 | import cv2 | 7 | import cv2 |
8 | import pandas as pd | 8 | import pandas as pd |
9 | from tools import get_file_paths, load_json | 9 | from tools import get_file_paths, load_json |
10 | from word2vec import simple_word2vec, jwq_word2vec | 10 | from word2vec import jwq_word2vec, simple_word2vec |
11 | 11 | ||
12 | 12 | ||
13 | def clean_go_res(go_res_dir): | 13 | def clean_go_res(go_res_dir): |
... | @@ -101,7 +101,7 @@ def build_anno_file(dataset_dir, anno_file_path): | ... | @@ -101,7 +101,7 @@ def build_anno_file(dataset_dir, anno_file_path): |
101 | df['name'] = img_list | 101 | df['name'] = img_list |
102 | df.to_csv(anno_file_path) | 102 | df.to_csv(anno_file_path) |
103 | 103 | ||
104 | def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir): | 104 | def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir, is_create_map=False): |
105 | """ | 105 | """ |
106 | Args: | 106 | Args: |
107 | img_dir: str image directory | 107 | img_dir: str image directory |
... | @@ -121,6 +121,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -121,6 +121,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
121 | group_cn_list = ['开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] | 121 | group_cn_list = ['开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] |
122 | test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28] | 122 | test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28] |
123 | 123 | ||
124 | create_map = {} | ||
124 | for img_name in sorted(os.listdir(img_dir)): | 125 | for img_name in sorted(os.listdir(img_dir)): |
125 | if img_name in skip_list: | 126 | if img_name in skip_list: |
126 | print('Info: skip {0}'.format(img_name)) | 127 | print('Info: skip {0}'.format(img_name)) |
... | @@ -188,8 +189,9 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -188,8 +189,9 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
188 | X = list() | 189 | X = list() |
189 | y_true = list() | 190 | y_true = list() |
190 | 191 | ||
191 | text_vec_max_lens = 15 * 50 | 192 | # text_vec_max_lens = 15 * 50 |
192 | dim = 1 + 5 + 8 + text_vec_max_lens | 193 | # dim = 1 + 5 + 8 + text_vec_max_lens |
194 | dim = 1 + 5 + 8 | ||
193 | num_classes = 10 | 195 | num_classes = 10 |
194 | for i in range(160): | 196 | for i in range(160): |
195 | if i >= valid_lens: | 197 | if i >= valid_lens: |
... | @@ -201,7 +203,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -201,7 +203,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
201 | feature_vec = [1.] | 203 | feature_vec = [1.] |
202 | feature_vec.extend(simple_word2vec(text)) | 204 | feature_vec.extend(simple_word2vec(text)) |
203 | feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h]) | 205 | feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h]) |
204 | feature_vec.extend(jwq_word2vec(text, text_vec_max_lens)) | 206 | # feature_vec.extend(jwq_word2vec(text, text_vec_max_lens)) |
205 | X.append(feature_vec) | 207 | X.append(feature_vec) |
206 | 208 | ||
207 | y_true.append([0 for _ in range(num_classes)]) | 209 | y_true.append([0 for _ in range(num_classes)]) |
... | @@ -211,7 +213,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -211,7 +213,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
211 | feature_vec = [0.] | 213 | feature_vec = [0.] |
212 | feature_vec.extend(simple_word2vec(text)) | 214 | feature_vec.extend(simple_word2vec(text)) |
213 | feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h]) | 215 | feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h]) |
214 | feature_vec.extend(jwq_word2vec(text, text_vec_max_lens)) | 216 | # feature_vec.extend(jwq_word2vec(text, text_vec_max_lens)) |
215 | X.append(feature_vec) | 217 | X.append(feature_vec) |
216 | 218 | ||
217 | base_label_list = [0 for _ in range(num_classes)] | 219 | base_label_list = [0 for _ in range(num_classes)] |
... | @@ -222,16 +224,34 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -222,16 +224,34 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
222 | feature_vec = [0.] | 224 | feature_vec = [0.] |
223 | feature_vec.extend(simple_word2vec(text)) | 225 | feature_vec.extend(simple_word2vec(text)) |
224 | feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h]) | 226 | feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h]) |
225 | feature_vec.extend(jwq_word2vec(text, text_vec_max_lens)) | 227 | # feature_vec.extend(jwq_word2vec(text, text_vec_max_lens)) |
226 | X.append(feature_vec) | 228 | X.append(feature_vec) |
227 | 229 | ||
228 | y_true.append([0 for _ in range(num_classes)]) | 230 | y_true.append([0 for _ in range(num_classes)]) |
229 | 231 | ||
230 | all_data = [X, y_true, valid_lens] | 232 | all_data = [X, y_true, valid_lens] |
231 | 233 | ||
232 | with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name))), 'w') as fp: | 234 | save_json_name = '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name)) |
235 | with open(os.path.join(save_dir, save_json_name), 'w') as fp: | ||
233 | json.dump(all_data, fp) | 236 | json.dump(all_data, fp) |
234 | 237 | ||
238 | if is_create_map: | ||
239 | create_map[img_name] = { | ||
240 | 'x_y_valid_lens': save_json_name, | ||
241 | 'find_top_text': [go_res_list[i][-1] for i in top_text_idx_set], | ||
242 | 'find_value': {group_cn_list[v]: go_res_list[k][-1] for k, v in label_idx_dict.items()} | ||
243 | } | ||
244 | |||
245 | |||
246 | # break | ||
247 | |||
248 | # print(create_map) | ||
249 | # print(is_create_map) | ||
250 | if create_map: | ||
251 | with open(os.path.join(os.path.dirname(save_dir), 'create_map.json'), 'w') as fp: | ||
252 | json.dump(create_map, fp) | ||
253 | |||
254 | |||
235 | # print('top text find:') | 255 | # print('top text find:') |
236 | # for i in top_text_idx_set: | 256 | # for i in top_text_idx_set: |
237 | # _, text = go_res_list[i] | 257 | # _, text = go_res_list[i] |
... | @@ -249,7 +269,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -249,7 +269,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
249 | if __name__ == '__main__': | 269 | if __name__ == '__main__': |
250 | base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' | 270 | base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' |
251 | go_dir = os.path.join(base_dir, 'go_res') | 271 | go_dir = os.path.join(base_dir, 'go_res') |
252 | dataset_save_dir = os.path.join(base_dir, 'dataset2') | 272 | dataset_save_dir = os.path.join(base_dir, 'dataset160x14') |
253 | label_dir = os.path.join(base_dir, 'labeled') | 273 | label_dir = os.path.join(base_dir, 'labeled') |
254 | 274 | ||
255 | train_go_path = os.path.join(go_dir, 'train') | 275 | train_go_path = os.path.join(go_dir, 'train') |
... | @@ -331,7 +351,7 @@ if __name__ == '__main__': | ... | @@ -331,7 +351,7 @@ if __name__ == '__main__': |
331 | build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) | 351 | build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) |
332 | build_anno_file(train_dataset_dir, train_anno_file_path) | 352 | build_anno_file(train_dataset_dir, train_anno_file_path) |
333 | 353 | ||
334 | build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir) | 354 | build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir, True) |
335 | build_anno_file(valid_dataset_dir, valid_anno_file_path) | 355 | build_anno_file(valid_dataset_dir, valid_anno_file_path) |
336 | 356 | ||
337 | # print(simple_word2vec(' fd2jk接口 额24;叁‘,。测ADF壹试!¥? ')) | 357 | # print(simple_word2vec(' fd2jk接口 额24;叁‘,。测ADF壹试!¥? ')) | ... | ... |
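Note: with is_create_map=True (now passed for the validation split only), build_dataset additionally writes create_map.json one level above save_dir, mapping each source image to its saved feature file plus the matched header texts and field values. A rough sketch of one entry; the file name, texts and values below are purely illustrative:

```python
# Illustrative only: the UUID-based file name, texts and values are made up.
create_map = {
    "example_invoice.jpg": {
        # name of the [X, y_true, valid_lens] JSON saved for this image
        "x_y_valid_lens": "6fa459ea-ee8a-3ca4-894e-db77e160355e.json",
        # texts matched against top_text_list
        "find_top_text": ["some header text"],
        # field name (from group_cn_list) -> recognized text
        "find_value": {"开票日期": "2022-01-01", "小写": "123456.00"},
    }
}
```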
draw.sh
0 → 100755
1 | CUDA_VISIBLE_DEVICES=0 nohup python main.py --config=config/sl.yaml -d > draw.log 2>&1 & | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
eval.sh
0 → 100755
1 | CUDA_VISIBLE_DEVICES=0 nohup python main.py --config=config/sl.yaml -e > eval.log 2>&1 & | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -8,6 +8,7 @@ def main(): | ... | @@ -8,6 +8,7 @@ def main(): |
8 | parser = argparse.ArgumentParser() | 8 | parser = argparse.ArgumentParser() |
9 | parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file') | 9 | parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file') |
10 | parser.add_argument('-e', '--eval', action="store_true") | 10 | parser.add_argument('-e', '--eval', action="store_true") |
11 | parser.add_argument('-d', '--draw', action="store_true") | ||
11 | args = parser.parse_args() | 12 | args = parser.parse_args() |
12 | 13 | ||
13 | cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader) | 14 | cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader) |
... | @@ -18,6 +19,8 @@ def main(): | ... | @@ -18,6 +19,8 @@ def main(): |
18 | 19 | ||
19 | if args.eval: | 20 | if args.eval: |
20 | solver.evaluate() | 21 | solver.evaluate() |
22 | elif args.draw: | ||
23 | solver.draw_val() | ||
21 | else: | 24 | else: |
22 | solver.run() | 25 | solver.run() |
23 | 26 | ... | ... |
... | @@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens): | ... | @@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens): |
18 | # [batch_size, num_heads, seq_len, seq_len] | 18 | # [batch_size, num_heads, seq_len, seq_len] |
19 | shape = X.shape | 19 | shape = X.shape |
20 | if valid_lens.dim() == 1: | 20 | if valid_lens.dim() == 1: |
21 | valid_lens = torch.repeat_interleave(valid_lens, shape[1]) | 21 | valid_lens = torch.repeat_interleave(valid_lens, shape[2]) |
22 | else: | 22 | else: |
23 | valid_lens = valid_lens.reshape(-1) | 23 | valid_lens = valid_lens.reshape(-1) |
24 | # On the last axis, replace masked elements with a very large negative | 24 | # On the last axis, replace masked elements with a very large negative | ... | ... |
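Note: the masked_softmax fix changes the repeat count from shape[1] (num_heads) to shape[2] (seq_len), so a per-example valid_lens of shape (batch_size,) gets one entry per query row once the [batch_size, num_heads, seq_len, seq_len] scores are flattened for sequence_mask, assuming a d2l-style reshape to (-1, seq_len); the counts match exactly with num_heads: 1 as in config/sl.yaml. A small self-contained illustration of torch.repeat_interleave with made-up shapes:

```python
import torch

valid_lens = torch.tensor([3, 5])            # one valid length per batch element
batch, heads, seq = 2, 1, 4                  # num_heads: 1 as in the config
scores = torch.zeros(batch, heads, seq, seq)

# old: repeat by shape[1] (= heads) -> only 2 entries
# new: repeat by shape[2] (= seq)   -> 8 entries, one per flattened score row
expanded = torch.repeat_interleave(valid_lens, scores.shape[2])
print(expanded)                              # tensor([3, 3, 3, 3, 5, 5, 5, 5])
print(expanded.numel() == scores.reshape(-1, seq).shape[0])  # True (heads == 1)
```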
1 | import copy | 1 | import copy |
2 | import os | 2 | import os |
3 | import cv2 | ||
4 | import json | ||
3 | 5 | ||
4 | import torch | 6 | import torch |
7 | from PIL import Image, ImageDraw, ImageFont | ||
5 | 8 | ||
6 | from data import build_dataloader | 9 | from data import build_dataloader |
7 | from loss import build_loss | 10 | from loss import build_loss |
... | @@ -34,6 +37,11 @@ class SLSolver(object): | ... | @@ -34,6 +37,11 @@ class SLSolver(object): |
34 | self.hyper_params = cfg['solver']['args'] | 37 | self.hyper_params = cfg['solver']['args'] |
35 | self.base_on = self.hyper_params['base_on'] | 38 | self.base_on = self.hyper_params['base_on'] |
36 | self.model_path = self.hyper_params['model_path'] | 39 | self.model_path = self.hyper_params['model_path'] |
40 | self.val_image_path = self.hyper_params['val_image_path'] | ||
41 | self.val_go_path = self.hyper_params['val_go_path'] | ||
42 | self.val_map_path = self.hyper_params['val_map_path'] | ||
43 | self.draw_font_path = self.hyper_params['draw_font_path'] | ||
44 | self.thresholds = self.hyper_params['thresholds'] | ||
37 | try: | 45 | try: |
38 | self.epoch = self.hyper_params['epoch'] | 46 | self.epoch = self.hyper_params['epoch'] |
39 | except Exception: | 47 | except Exception: |
... | @@ -41,19 +49,22 @@ class SLSolver(object): | ... | @@ -41,19 +49,22 @@ class SLSolver(object): |
41 | 49 | ||
42 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) | 50 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) |
43 | 51 | ||
44 | def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5): | 52 | def accuracy(self, y_pred, y_true, valid_lens, eval=False): |
45 | # [batch_size, seq_len, num_classes] | 53 | # [batch_size, seq_len, num_classes] |
46 | y_pred_sigmoid = torch.nn.Sigmoid()(y_pred) | 54 | y_pred_sigmoid = torch.nn.Sigmoid()(y_pred) |
47 | # [batch_size, seq_len] | 55 | # [batch_size, seq_len] |
48 | y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1 | 56 | y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1 |
49 | # [batch_size, seq_len] | 57 | # [batch_size, seq_len] |
50 | y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int() | 58 | y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > self.thresholds).int() |
51 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | 59 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) |
52 | 60 | ||
53 | y_true_idx = torch.argmax(y_true, dim=-1) + 1 | 61 | y_true_idx = torch.argmax(y_true, dim=-1) + 1 |
54 | y_true_is_other = torch.sum(y_true, dim=-1).int() | 62 | y_true_is_other = torch.sum(y_true, dim=-1).int() |
55 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) | 63 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) |
56 | 64 | ||
65 | if eval: | ||
66 | return y_pred_rebuild, y_true_rebuild | ||
67 | |||
57 | masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1) | 68 | masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1) |
58 | 69 | ||
59 | return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item() | 70 | return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item() |
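Note: accuracy() now reads its decision threshold from the new thresholds: 0.5 config entry and, when called with eval=True, returns the rebuilt 1-based class ids (0 meaning "other") instead of a correct-prediction count. A worked example of the rebuild step with made-up logits for one sequence of three tokens and three classes:

```python
import torch

# Hypothetical logits, shape [batch=1, seq_len=3, num_classes=3].
y_pred = torch.tensor([[[ 2.0, -1.0, -2.0],   # confident class 0 -> rebuilt id 1
                        [-3.0, -2.0, -4.0],   # nothing clears 0.5 -> "other" (id 0)
                        [-1.0,  0.5,  3.0]]])  # confident class 2 -> rebuilt id 3
sig = torch.sigmoid(y_pred)
idx = torch.argmax(sig, dim=-1) + 1               # 1-based class ids
keep = (torch.amax(sig, dim=-1) > 0.5).int()      # 0 where max sigmoid <= threshold
print(torch.multiply(idx, keep))                  # tensor([[1, 0, 3]])
```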
... | @@ -168,19 +179,7 @@ class SLSolver(object): | ... | @@ -168,19 +179,7 @@ class SLSolver(object): |
168 | # pred = torch.nn.Sigmoid()(self.model(X)) | 179 | # pred = torch.nn.Sigmoid()(self.model(X)) |
169 | y_pred = self.model(X, valid_lens) | 180 | y_pred = self.model(X, valid_lens) |
170 | 181 | ||
171 | # [batch_size, seq_len, num_classes] | 182 | y_pred_rebuild, y_true_rebuild = self.accuracy(y_pred, y_true, valid_lens, eval=True) |
172 | y_pred_sigmoid = torch.nn.Sigmoid()(y_pred) | ||
173 | # [batch_size, seq_len] | ||
174 | y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1 | ||
175 | # [batch_size, seq_len] | ||
176 | y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > 0.5).int() | ||
177 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | ||
178 | |||
179 | y_true_idx = torch.argmax(y_true, dim=-1) + 1 | ||
180 | y_true_is_other = torch.sum(y_true, dim=-1).int() | ||
181 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) | ||
182 | |||
183 | # masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1) | ||
184 | 183 | ||
185 | for idx, seq_result in enumerate(y_true_rebuild.cpu().numpy().tolist()): | 184 | for idx, seq_result in enumerate(y_true_rebuild.cpu().numpy().tolist()): |
186 | label_true_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]]) | 185 | label_true_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]]) |
... | @@ -193,3 +192,111 @@ class SLSolver(object): | ... | @@ -193,3 +192,111 @@ class SLSolver(object): |
193 | print(acc) | 192 | print(acc) |
194 | print(cm) | 193 | print(cm) |
195 | print(report) | 194 | print(report) |
195 | |||
196 | def draw_val(self): | ||
197 | if not os.path.isdir(self.val_image_path): | ||
198 | print('Warn: val_image_path does not exist: {0}'.format(self.val_image_path)) | ||
199 | return | ||
200 | |||
201 | if not os.path.isdir(self.val_go_path): | ||
202 | print('Warn: val_go_path does not exist: {0}'.format(self.val_go_path)) | ||
203 | return | ||
204 | |||
205 | if not os.path.isfile(self.val_map_path): | ||
206 | print('Warn: val_map_path does not exist: {0}'.format(self.val_map_path)) | ||
207 | return | ||
208 | |||
209 | map_key_input = 'x_y_valid_lens' | ||
210 | map_key_text = 'find_top_text' | ||
211 | map_key_value = 'find_value' | ||
212 | group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] | ||
213 | |||
214 | dataset_base_dir = os.path.dirname(self.val_map_path) | ||
215 | val_dataset_dir = os.path.join(dataset_base_dir, 'valid') | ||
216 | save_dir = os.path.join(dataset_base_dir, 'draw_val') | ||
217 | if not os.path.isdir(save_dir): | ||
218 | os.makedirs(save_dir, exist_ok=True) | ||
219 | |||
220 | self.model.eval() | ||
221 | |||
222 | with open(self.val_map_path, 'r') as fp: | ||
223 | val_map = json.load(fp) | ||
224 | |||
225 | for img_name in sorted(os.listdir(self.val_image_path)): | ||
226 | print('Info: start {0}'.format(img_name)) | ||
227 | image_path = os.path.join(self.val_image_path, img_name) | ||
228 | |||
229 | img = cv2.imread(image_path) | ||
230 | im_h, im_w, _ = img.shape | ||
231 | img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) | ||
232 | draw = ImageDraw.Draw(img_pil) | ||
233 | |||
234 | if im_h < im_w: | ||
235 | size = int(im_h * 0.015) | ||
236 | else: | ||
237 | size = int(im_w * 0.015) | ||
238 | if size < 14: | ||
239 | size = 14 | ||
240 | font = ImageFont.truetype(self.draw_font_path, size, encoding='utf-8') | ||
241 | |||
242 | green_color = (0, 255, 0) | ||
243 | red_color = (255, 0, 0) | ||
244 | blue_color = (0, 0, 255) | ||
245 | |||
246 | base_image_name, _ = os.path.splitext(img_name) | ||
247 | go_res_json_path = os.path.join(self.val_go_path, '{0}.json'.format(base_image_name)) | ||
248 | with open(go_res_json_path, 'r') as fp: | ||
249 | go_res_list = json.load(fp) | ||
250 | |||
251 | with open(os.path.join(val_dataset_dir, val_map[img_name][map_key_input]), 'r') as fp: | ||
252 | input_list, label_list, valid_lens_scalar = json.load(fp) | ||
253 | |||
254 | X = torch.tensor(input_list).unsqueeze(0).to(self.device) | ||
255 | y_true = torch.tensor(label_list).unsqueeze(0).float().to(self.device) | ||
256 | valid_lens = torch.tensor([valid_lens_scalar, ]).to(self.device) | ||
257 | del input_list | ||
258 | del label_list | ||
259 | |||
260 | y_pred = self.model(X, valid_lens) | ||
261 | |||
262 | y_pred_rebuild, y_true_rebuild = self.accuracy(y_pred, y_true, valid_lens, eval=True) | ||
263 | pred = y_pred_rebuild.cpu().numpy().tolist()[0] | ||
264 | label = y_true_rebuild.cpu().numpy().tolist()[0] | ||
265 | |||
266 | correct = 0 | ||
267 | bbox_draw_dict = dict() | ||
268 | for i in range(valid_lens_scalar): | ||
269 | if pred[i] == label[i]: | ||
270 | correct += 1 | ||
271 | if pred[i] != 0: | ||
272 | # green: prediction matches the label | ||
273 | bbox_draw_dict[i] = (group_cn_list[pred[i]], ) | ||
274 | else: | ||
275 | # red: ground-truth label at top-left, prediction at top-right | ||
276 | bbox_draw_dict[i] = (group_cn_list[label[i]], group_cn_list[pred[i]]) | ||
277 | |||
278 | correct_rate = correct / valid_lens_scalar | ||
279 | |||
280 | # draw the boxes and their labels | ||
281 | for idx, text_tuple in bbox_draw_dict.items(): | ||
282 | (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[idx] | ||
283 | line_color = green_color if len(text_tuple) == 1 else red_color | ||
284 | draw.polygon([(x0, y0), (x1, y1), (x2, y2), (x3, y3)], outline=line_color) | ||
285 | draw.text((int(x0), int(y0)), text_tuple[0], green_color, font=font) | ||
286 | if len(text_tuple) == 2: | ||
287 | draw.text((int(x1), int(y1)), text_tuple[1], red_color, font=font) | ||
288 | |||
289 | draw.text((0, 0), str(correct_rate), blue_color, font=font) | ||
290 | |||
291 | last_y = size | ||
292 | for k, v in val_map[img_name][map_key_value].items(): | ||
293 | draw.text((0, last_y), '{0}: {1}'.format(k, v), blue_color, font=font) | ||
294 | last_y += size | ||
295 | |||
296 | img_pil.save(os.path.join(save_dir, img_name)) | ||
297 | |||
298 | # break | ||
299 | |||
300 | |||
301 | |||
302 | ... | ... |
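Note: with the pieces above wired together, running draw.sh (main.py -d) calls draw_val(), which re-runs the model on each validation image's saved features, draws a green polygon with the predicted field name where prediction and label agree on a non-"other" token, a red polygon with the ground-truth label at the top-left corner and the prediction at the top-right where they disagree, writes the per-image accuracy and the field values from create_map.json in blue, and saves the annotated images to a draw_val directory next to the validation feature files.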
train.sh
100644 → 100755
-