092baca7 by 周伟奇

fix bug

1 parent 60c39554
...@@ -206,10 +206,25 @@ class SLSolver(object): ...@@ -206,10 +206,25 @@ class SLSolver(object):
206 print('Warn: val_map_path not exists: {0}'.format(self.val_map_path)) 206 print('Warn: val_map_path not exists: {0}'.format(self.val_map_path))
207 return 207 return
208 208
209 if isinstance(self.model_path, str) and os.path.exists(self.model_path):
210 self.model.load_state_dict(torch.load(self.model_path))
211 self.logger.info(f'==> Load Model from {self.model_path}')
212 else:
213 return
214
215 self.model.eval()
216
209 map_key_input = 'x_y_valid_lens' 217 map_key_input = 'x_y_valid_lens'
210 map_key_text = 'find_top_text' 218 map_key_text = 'find_top_text'
211 map_key_value = 'find_value' 219 map_key_value = 'find_value'
212 group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] 220 group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
221 skip_list_valid = [
222 'CH-B102897920-2.jpg',
223 'CH-B102551284-0.jpg',
224 'CH-B102879376-2.jpg',
225 'CH-B101509488-page-16.jpg',
226 'CH-B102708352-2.jpg',
227 ]
213 228
214 dataset_base_dir = os.path.dirname(self.val_map_path) 229 dataset_base_dir = os.path.dirname(self.val_map_path)
215 val_dataset_dir = os.path.join(dataset_base_dir, 'valid') 230 val_dataset_dir = os.path.join(dataset_base_dir, 'valid')
...@@ -217,12 +232,13 @@ class SLSolver(object): ...@@ -217,12 +232,13 @@ class SLSolver(object):
217 if not os.path.isdir(save_dir): 232 if not os.path.isdir(save_dir):
218 os.makedirs(save_dir, exist_ok=True) 233 os.makedirs(save_dir, exist_ok=True)
219 234
220 self.model.eval()
221
222 with open(self.val_map_path, 'r') as fp: 235 with open(self.val_map_path, 'r') as fp:
223 val_map = json.load(fp) 236 val_map = json.load(fp)
224 237
225 for img_name in sorted(os.listdir(self.val_image_path)): 238 for img_name in sorted(os.listdir(self.val_image_path)):
239 if img_name in skip_list_valid:
240 continue
241
226 print('Info: start {0}'.format(img_name)) 242 print('Info: start {0}'.format(img_name))
227 image_path = os.path.join(self.val_image_path, img_name) 243 image_path = os.path.join(self.val_image_path, img_name)
228 244
...@@ -232,11 +248,11 @@ class SLSolver(object): ...@@ -232,11 +248,11 @@ class SLSolver(object):
232 draw = ImageDraw.Draw(img_pil) 248 draw = ImageDraw.Draw(img_pil)
233 249
234 if im_h < im_w: 250 if im_h < im_w:
235 size = int(im_h * 0.015) 251 size = int(im_h * 0.010)
236 else: 252 else:
237 size = int(im_w * 0.015) 253 size = int(im_w * 0.010)
238 if size < 14: 254 if size < 10:
239 size = 14 255 size = 10
240 font = ImageFont.truetype(self.draw_font_path, size, encoding='utf-8') 256 font = ImageFont.truetype(self.draw_font_path, size, encoding='utf-8')
241 257
242 green_color = (0, 255, 0) 258 green_color = (0, 255, 0)
...@@ -253,7 +269,7 @@ class SLSolver(object): ...@@ -253,7 +269,7 @@ class SLSolver(object):
253 269
254 X = torch.tensor(input_list).unsqueeze(0).to(self.device) 270 X = torch.tensor(input_list).unsqueeze(0).to(self.device)
255 y_true = torch.tensor(label_list).unsqueeze(0).float().to(self.device) 271 y_true = torch.tensor(label_list).unsqueeze(0).float().to(self.device)
256 valid_lens = torch.tenor([valid_lens_scalar, ]).to(self.device) 272 valid_lens = torch.tensor([valid_lens_scalar, ]).to(self.device)
257 del input_list 273 del input_list
258 del label_list 274 del label_list
259 275
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!