fix bug
Showing
1 changed file
with
23 additions
and
7 deletions
... | @@ -206,10 +206,25 @@ class SLSolver(object): | ... | @@ -206,10 +206,25 @@ class SLSolver(object): |
206 | print('Warn: val_map_path not exists: {0}'.format(self.val_map_path)) | 206 | print('Warn: val_map_path not exists: {0}'.format(self.val_map_path)) |
207 | return | 207 | return |
208 | 208 | ||
209 | if isinstance(self.model_path, str) and os.path.exists(self.model_path): | ||
210 | self.model.load_state_dict(torch.load(self.model_path)) | ||
211 | self.logger.info(f'==> Load Model from {self.model_path}') | ||
212 | else: | ||
213 | return | ||
214 | |||
215 | self.model.eval() | ||
216 | |||
209 | map_key_input = 'x_y_valid_lens' | 217 | map_key_input = 'x_y_valid_lens' |
210 | map_key_text = 'find_top_text' | 218 | map_key_text = 'find_top_text' |
211 | map_key_value = 'find_value' | 219 | map_key_value = 'find_value' |
212 | group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] | 220 | group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] |
221 | skip_list_valid = [ | ||
222 | 'CH-B102897920-2.jpg', | ||
223 | 'CH-B102551284-0.jpg', | ||
224 | 'CH-B102879376-2.jpg', | ||
225 | 'CH-B101509488-page-16.jpg', | ||
226 | 'CH-B102708352-2.jpg', | ||
227 | ] | ||
213 | 228 | ||
214 | dataset_base_dir = os.path.dirname(self.val_map_path) | 229 | dataset_base_dir = os.path.dirname(self.val_map_path) |
215 | val_dataset_dir = os.path.join(dataset_base_dir, 'valid') | 230 | val_dataset_dir = os.path.join(dataset_base_dir, 'valid') |
... | @@ -217,12 +232,13 @@ class SLSolver(object): | ... | @@ -217,12 +232,13 @@ class SLSolver(object): |
217 | if not os.path.isdir(save_dir): | 232 | if not os.path.isdir(save_dir): |
218 | os.makedirs(save_dir, exist_ok=True) | 233 | os.makedirs(save_dir, exist_ok=True) |
219 | 234 | ||
220 | self.model.eval() | ||
221 | |||
222 | with open(self.val_map_path, 'r') as fp: | 235 | with open(self.val_map_path, 'r') as fp: |
223 | val_map = json.load(fp) | 236 | val_map = json.load(fp) |
224 | 237 | ||
225 | for img_name in sorted(os.listdir(self.val_image_path)): | 238 | for img_name in sorted(os.listdir(self.val_image_path)): |
239 | if img_name in skip_list_valid: | ||
240 | continue | ||
241 | |||
226 | print('Info: start {0}'.format(img_name)) | 242 | print('Info: start {0}'.format(img_name)) |
227 | image_path = os.path.join(self.val_image_path, img_name) | 243 | image_path = os.path.join(self.val_image_path, img_name) |
228 | 244 | ||
... | @@ -232,11 +248,11 @@ class SLSolver(object): | ... | @@ -232,11 +248,11 @@ class SLSolver(object): |
232 | draw = ImageDraw.Draw(img_pil) | 248 | draw = ImageDraw.Draw(img_pil) |
233 | 249 | ||
234 | if im_h < im_w: | 250 | if im_h < im_w: |
235 | size = int(im_h * 0.015) | 251 | size = int(im_h * 0.010) |
236 | else: | 252 | else: |
237 | size = int(im_w * 0.015) | 253 | size = int(im_w * 0.010) |
238 | if size < 14: | 254 | if size < 10: |
239 | size = 14 | 255 | size = 10 |
240 | font = ImageFont.truetype(self.draw_font_path, size, encoding='utf-8') | 256 | font = ImageFont.truetype(self.draw_font_path, size, encoding='utf-8') |
241 | 257 | ||
242 | green_color = (0, 255, 0) | 258 | green_color = (0, 255, 0) |
... | @@ -253,7 +269,7 @@ class SLSolver(object): | ... | @@ -253,7 +269,7 @@ class SLSolver(object): |
253 | 269 | ||
254 | X = torch.tensor(input_list).unsqueeze(0).to(self.device) | 270 | X = torch.tensor(input_list).unsqueeze(0).to(self.device) |
255 | y_true = torch.tensor(label_list).unsqueeze(0).float().to(self.device) | 271 | y_true = torch.tensor(label_list).unsqueeze(0).float().to(self.device) |
256 | valid_lens = torch.tenor([valid_lens_scalar, ]).to(self.device) | 272 | valid_lens = torch.tensor([valid_lens_scalar, ]).to(self.device) |
257 | del input_list | 273 | del input_list |
258 | del label_list | 274 | del label_list |
259 | 275 | ... | ... |
-
Please register or sign in to post a comment