diff --git a/solver/sl_solver.py b/solver/sl_solver.py index 70ee08f..0172674 100644 --- a/solver/sl_solver.py +++ b/solver/sl_solver.py @@ -206,10 +206,25 @@ class SLSolver(object): print('Warn: val_map_path not exists: {0}'.format(self.val_map_path)) return + if isinstance(self.model_path, str) and os.path.exists(self.model_path): + self.model.load_state_dict(torch.load(self.model_path)) + self.logger.info(f'==> Load Model from {self.model_path}') + else: + return + + self.model.eval() + map_key_input = 'x_y_valid_lens' map_key_text = 'find_top_text' map_key_value = 'find_value' group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] + skip_list_valid = [ + 'CH-B102897920-2.jpg', + 'CH-B102551284-0.jpg', + 'CH-B102879376-2.jpg', + 'CH-B101509488-page-16.jpg', + 'CH-B102708352-2.jpg', + ] dataset_base_dir = os.path.dirname(self.val_map_path) val_dataset_dir = os.path.join(dataset_base_dir, 'valid') @@ -217,12 +232,13 @@ class SLSolver(object): if not os.path.isdir(save_dir): os.makedirs(save_dir, exist_ok=True) - self.model.eval() - with open(self.val_map_path, 'r') as fp: val_map = json.load(fp) for img_name in sorted(os.listdir(self.val_image_path)): + if img_name in skip_list_valid: + continue + print('Info: start {0}'.format(img_name)) image_path = os.path.join(self.val_image_path, img_name) @@ -232,11 +248,11 @@ class SLSolver(object): draw = ImageDraw.Draw(img_pil) if im_h < im_w: - size = int(im_h * 0.015) + size = int(im_h * 0.010) else: - size = int(im_w * 0.015) - if size < 14: - size = 14 + size = int(im_w * 0.010) + if size < 10: + size = 10 font = ImageFont.truetype(self.draw_font_path, size, encoding='utf-8') green_color = (0, 255, 0) @@ -253,7 +269,7 @@ class SLSolver(object): X = torch.tensor(input_list).unsqueeze(0).to(self.device) y_true = torch.tensor(label_list).unsqueeze(0).float().to(self.device) - valid_lens = torch.tenor([valid_lens_scalar, ]).to(self.device) + valid_lens = torch.tensor([valid_lens_scalar, ]).to(self.device) del input_list del label_list