From 59cbfab2302f8f30cec976aee85d93188e7450fa Mon Sep 17 00:00:00 2001
From: 周伟奇 <zhouweiqi@situdata.com>
Date: Fri, 16 Oct 2020 17:31:38 +0800
Subject: [PATCH] fix bug & add skip_img_sheet

---
 .gitignore                                          |   3 ++-
 src/apps/doc/consts.py                              |   2 ++
 src/apps/doc/management/commands/doc_ocr_process.py | 280 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------------------------------------------------------------------------------------------------------------
 src/apps/doc/ocr/wb.py                              |  56 +++++++++++++++++++++++++++-----------------------------
 4 files changed, 175 insertions(+), 166 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5d7f9ee..a8919aa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,4 +33,5 @@ data/*
 # 脚本
 src/*.sh
 
-test*
\ No newline at end of file
+test*
+ocr_test.py
\ No newline at end of file
diff --git a/src/apps/doc/consts.py b/src/apps/doc/consts.py
index 7849036..356549c 100644
--- a/src/apps/doc/consts.py
+++ b/src/apps/doc/consts.py
@@ -60,6 +60,8 @@ TRANS_MAP = {
 }
 TRANS = str.maketrans(TRANS_MAP)
 ERROR_CHARS = {'.', '·', '•'}
+SKIP_IMG_SHEET_NAME = '未处理图片'
+SKIP_IMG_SHEET_HEADER = ('页码', '序号')
 
 CARD_RATIO = 0.9
 UNKNOWN_CARD = '未知卡号'
diff --git a/src/apps/doc/management/commands/doc_ocr_process.py b/src/apps/doc/management/commands/doc_ocr_process.py
index dc6fe62..586e87e 100644
--- a/src/apps/doc/management/commands/doc_ocr_process.py
+++ b/src/apps/doc/management/commands/doc_ocr_process.py
@@ -80,19 +80,20 @@ class Command(BaseCommand, LoggerMixin):
             self.log_base, business_type, doc.id, pdf_path))
         return doc_data_path, excel_path, src_excel_path, pdf_path
 
-    @staticmethod
-    def bs_process(wb, ocr_data, bs_summary, unknown_summary, img_path, classify):
+    def bs_process(self, wb, ocr_data, bs_summary, unknown_summary, img_path, classify, skip_img):
         sheets = ocr_data.get('data', [])
         if not sheets:
+            skip_img.append(self.parse_img_path(img_path))
             return
         confidence = ocr_data.get('confidence', 1)
         img_name, _ = os.path.splitext(os.path.basename(img_path))
         for i, sheet in enumerate(sheets):
-            sheet_name = '{0}_{1}'.format(img_name, i)
-            ws = wb.create_sheet(sheet_name)
             cells = sheet.get('cells')
             if not cells:
+                skip_img.append(self.parse_img_path(img_path))
                 continue
+            sheet_name = '{0}_{1}'.format(img_name, i)
+            ws = wb.create_sheet(sheet_name)
             for cell in cells:
                 c1 = cell.get('start_column')
                 r1 = cell.get('start_row')
@@ -147,9 +148,10 @@ class Command(BaseCommand, LoggerMixin):
                     ed_list.append(summary[6])
 
     @staticmethod
-    def license1_process(ocr_data, license_summary, classify):
+    def license1_process(ocr_data, license_summary, classify, skip_img, img_path):
         license_data = ocr_data.get('data', [])
         if not license_data:
+            skip_img.append(Command.parse_img_path(img_path))
             return
         for license_dict in license_data:
             res_list = []
@@ -157,8 +159,7 @@ class Command(BaseCommand, LoggerMixin):
                 res_list.append((field, value))
             license_summary.setdefault(classify, []).append(res_list)
 
-    @staticmethod
-    def license2_process(ocr_res_2, license_summary, pid, classify):
+    def license2_process(self, ocr_res_2, license_summary, pid, classify, skip_img, img_path):
         if ocr_res_2.get('ErrorCode') in consts.SUCCESS_CODE_SET:
             if pid == consts.BC_PID:
                 # 银行卡
@@ -174,113 +175,16 @@ class Command(BaseCommand, LoggerMixin):
                         res_list.append(
                             (field_dict.get('chn_key', ''), field_dict.get('value', '')))
                     license_summary.setdefault(classify, []).append(res_list)
-
-    async def fetch_ocr_result(self, url, json_data):
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session:
-            async with session.post(url, json=json_data) as response:
-                if response.status == 200:
-                    return await response.json()
-
-    # async def img_2_ocr_2_wb(self, wb, img_path, summary):
-    #     res = await self.fetch_ocr_result(img_path)
-    #     self.cronjob_log.info('{0} [fetch ocr result success] [img={1}] [res={2}]'.format(self.log_base, img_path, res))
-    #     sheets_list = res.get('result').get('res')
-    #     img_name = os.path.basename(img_path)
-    #     self.append_sheet(wb, sheets_list, img_name, summary)
-
-    async def img_2_ocr_2_wb(self, wb, img_path, bs_summary, unknown_summary, license_summary):
-        with open(img_path, 'rb') as f:
-            base64_data = base64.b64encode(f.read())
-            # 获取解码后的base64值
-            file_data = base64_data.decode()
-        json_data_1 = {
-            "file": file_data
-        }
-        ocr_res_1 = await self.fetch_ocr_result(self.ocr_url_1, json_data_1)
-        if ocr_res_1 is None:
-            raise Exception('ocr 1 error, img_path={0}'.format(img_path))
         else:
-            self.cronjob_log.info('{0} [ocr_1 result] [img={1}] [res={2}]'.format(
-                self.log_base, img_path, ocr_res_1))
-
-            if ocr_res_1.get('code') == 1:
-                ocr_data = ocr_res_1.get('data', {})
-                classify = ocr_data.get('classify')
-                if classify is None:
-                    return
-                elif classify in consts.OTHER_CLASSIFY_SET:  # 其他类
-                    return
-                elif classify in consts.LICENSE_CLASSIFY_SET_1:  # 证件1
-                    self.license1_process(ocr_data, license_summary, classify)
-                elif classify in consts.LICENSE_CLASSIFY_SET_2:  # 证件2
-                    pid, _ = consts.LICENSE_CLASSIFY_MAPPING.get(classify)
-                    json_data_2 = {
-                        "pid": str(pid),
-                        "key": conf.OCR_KEY,
-                        "secret": conf.OCR_SECRET,
-                        "file": file_data
-                    }
-                    ocr_res_2 = await self.fetch_ocr_result(self.ocr_url_2, json_data_2)
-                    if ocr_res_2 is None:
-                        raise Exception('ocr 2 error, img_path={0}'.format(img_path))
-                    else:
-                        # 识别结果
-                        self.cronjob_log.info('{0} [ocr_2 result] [img={1}] [res={2}]'.format(
-                            self.log_base, img_path, ocr_res_2))
-                        self.license2_process(ocr_res_2, license_summary, pid, classify)
-                else:  # 流水处理
-                    self.bs_process(wb, ocr_data, bs_summary, unknown_summary, img_path, classify)
+            skip_img.append(self.parse_img_path(img_path))
 
-    # def img_2_ocr_2_wb(self, wb, img_path, bs_summary, unknown_summary, license_summary):
-    #     # # 流水
-    #     # res = {
-    #     #     'code': 1,
-    #     #     'msg': 'success',
-    #     #     'data': {
-    #     #         'classify': 0,
-    #     #         'confidence': 0.999,
-    #     #         'data': [
-    #     #             {
-    #     #                 'summary': ['户名', '卡号', '页码', '回单验证码', '打印时间', '起始时间', '终止时间'],
-    #     #                 'cells': []
-    #     #             },
-    #     #             {
-    #     #                 'summary': ['户名', '卡号', '页码', '回单验证码', '打印时间', '起始时间', '终止时间'],
-    #     #                 'cells': []
-    #     #             }
-    #     #         ]
-    #     #     }
-    #     # }
-    #     #
-    #     # # 证件-1
-    #     # res = {
-    #     #     'code': 1,
-    #     #     'msg': 'success',
-    #     #     'data': {
-    #     #         'classify': 0,
-    #     #         'confidence': 0.999,
-    #     #         'data': [
-    #     #             {
-    #     #                 'cn_key': 'value',
-    #     #                 'cn_key': 'value',
-    #     #             },
-    #     #             {
-    #     #                 'cn_key': 'value',
-    #     #                 'cn_key': 'value',
-    #     #             },
-    #     #         ]
-    #     #     }
-    #     # }
-    #     #
-    #     # # 证件-2 or 其他类
-    #     # res = {
-    #     #     'code': 1,
-    #     #     'msg': 'success',
-    #     #     'data': {
-    #     #         'classify': 0,
-    #     #         'confidence': 0.999,
-    #     #     }
-    #     # }
+    # async def fetch_ocr_result(self, url, json_data):
+    #     async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session:
+    #         async with session.post(url, json=json_data) as response:
+    #             if response.status == 200:
+    #                 return await response.json()
+    #
+    # async def img_2_ocr_2_wb(self, wb, img_path, bs_summary, unknown_summary, license_summary):
     #     with open(img_path, 'rb') as f:
     #         base64_data = base64.b64encode(f.read())
     #         # 获取解码后的base64值
@@ -288,9 +192,10 @@ class Command(BaseCommand, LoggerMixin):
     #     json_data_1 = {
     #         "file": file_data
     #     }
-    #     response_1 = requests.post(self.ocr_url_1, json=json_data_1)
-    #     if response_1.status_code == 200:
-    #         ocr_res_1 = response_1.json()
+    #     ocr_res_1 = await self.fetch_ocr_result(self.ocr_url_1, json_data_1)
+    #     if ocr_res_1 is None:
+    #         raise Exception('ocr 1 error, img_path={0}'.format(img_path))
+    #     else:
     #         self.cronjob_log.info('{0} [ocr_1 result] [img={1}] [res={2}]'.format(
     #             self.log_base, img_path, ocr_res_1))
     #
@@ -311,21 +216,119 @@ class Command(BaseCommand, LoggerMixin):
     #                     "secret": conf.OCR_SECRET,
     #                     "file": file_data
     #                 }
-    #                 response_2 = requests.post(self.ocr_url_2, data=json_data_2)
-    #                 if response_2.status_code == 200:
+    #                 ocr_res_2 = await self.fetch_ocr_result(self.ocr_url_2, json_data_2)
+    #                 if ocr_res_2 is None:
+    #                     raise Exception('ocr 2 error, img_path={0}'.format(img_path))
+    #                 else:
     #                     # 识别结果
-    #                     ocr_res_2 = response_2.json()
     #                     self.cronjob_log.info('{0} [ocr_2 result] [img={1}] [res={2}]'.format(
     #                         self.log_base, img_path, ocr_res_2))
     #                     self.license2_process(ocr_res_2, license_summary, pid, classify)
-    #                 else:
-    #                     raise Exception('ocr 2 error, img_path={0}'.format(img_path))
     #             else:  # 流水处理
     #                 self.bs_process(wb, ocr_data, bs_summary, unknown_summary, img_path, classify)
-    #         else:
-    #             pass
-    #     else:
-    #         raise Exception('ocr 1 error, img_path={0}'.format(img_path))
+
+    def img_2_ocr_2_wb(self, wb, img_path, bs_summary, unknown_summary, license_summary, skip_img):
+        # # bank statement
+        # res = {
+        #     'code': 1,
+        #     'msg': 'success',
+        #     'data': {
+        #         'classify': 0,
+        #         'confidence': 0.999,
+        #         'data': [
+        #             {
+        #                 'summary': ['户名', '卡号', '页码', '回单验证码', '打印时间', '起始时间', '终止时间'],
+        #                 'cells': []
+        #             },
+        #             {
+        #                 'summary': ['户名', '卡号', '页码', '回单验证码', '打印时间', '起始时间', '终止时间'],
+        #                 'cells': []
+        #             }
+        #         ]
+        #     }
+        # }
+        #
+        # # license type 1
+        # res = {
+        #     'code': 1,
+        #     'msg': 'success',
+        #     'data': {
+        #         'classify': 0,
+        #         'confidence': 0.999,
+        #         'data': [
+        #             {
+        #                 'cn_key': 'value',
+        #                 'cn_key': 'value',
+        #             },
+        #             {
+        #                 'cn_key': 'value',
+        #                 'cn_key': 'value',
+        #             },
+        #         ]
+        #     }
+        # }
+        #
+        # # license type 2 or other classes
+        # res = {
+        #     'code': 1,
+        #     'msg': 'success',
+        #     'data': {
+        #         'classify': 0,
+        #         'confidence': 0.999,
+        #     }
+        # }
+        with open(img_path, 'rb') as f:
+            base64_data = base64.b64encode(f.read())
+            # decode the base64 bytes into a str
+            file_data = base64_data.decode()
+        json_data_1 = {
+            "file": file_data
+        }
+        response_1 = requests.post(self.ocr_url_1, json=json_data_1)
+        if response_1.status_code == 200:
+            ocr_res_1 = response_1.json()
+            self.cronjob_log.info('{0} [ocr_1 result] [img={1}] [res={2}]'.format(
+                self.log_base, img_path, ocr_res_1))
+
+            if ocr_res_1.get('code') == 1:
+                ocr_data = ocr_res_1.get('data', {})
+                classify = ocr_data.get('classify')
+                if classify is None:
+                    skip_img.append(self.parse_img_path(img_path))
+                    return
+                elif classify in consts.OTHER_CLASSIFY_SET:  # other classes
+                    skip_img.append(self.parse_img_path(img_path))
+                    return
+                elif classify in consts.LICENSE_CLASSIFY_SET_1:  # license type 1
+                    self.license1_process(ocr_data, license_summary, classify, skip_img, img_path)
+                elif classify in consts.LICENSE_CLASSIFY_SET_2:  # license type 2
+                    pid, _ = consts.LICENSE_CLASSIFY_MAPPING.get(classify)
+                    json_data_2 = {
+                        "pid": str(pid),
+                        "key": conf.OCR_KEY,
+                        "secret": conf.OCR_SECRET,
+                        "file": file_data
+                    }
+                    response_2 = requests.post(self.ocr_url_2, data=json_data_2)
+                    if response_2.status_code == 200:
+                        # recognition result
+                        ocr_res_2 = response_2.json()
+                        self.cronjob_log.info('{0} [ocr_2 result] [img={1}] [res={2}]'.format(
+                            self.log_base, img_path, ocr_res_2))
+                        self.license2_process(ocr_res_2, license_summary, pid, classify, skip_img, img_path)
+                    else:
+                        raise Exception('ocr 2 error, img_path={0}'.format(img_path))
+                else:  # bank statement processing
+                    self.bs_process(wb, ocr_data, bs_summary, unknown_summary, img_path, classify, skip_img)
+            else:
+                skip_img.append(self.parse_img_path(img_path))
+        else:
+            raise Exception('ocr 1 error, img_path={0}'.format(img_path))
+
+    @staticmethod
+    def parse_img_path(img_path):  # returns a 2-tuple of single characters taken from the file stem
+        img_name, _ = os.path.splitext(os.path.basename(img_path))
+        return img_name[5], img_name[11]  # NOTE(review): one char each; confirm page/index never exceed one digit
 
     @staticmethod
     def get_most(value_list):
@@ -425,8 +428,10 @@ class Command(BaseCommand, LoggerMixin):
                     merged_bs_summary[card] = summary
         else:
             # 1卡号
+            one_card = False
             if len(bs_summary) == 1:
                 merged_bs_summary = self.prune_bs_summary(bs_summary)
+                one_card = True
             # 多卡号
             else:
                 merged_bs_summary = self.merge_card(bs_summary)
@@ -435,7 +440,7 @@ class Command(BaseCommand, LoggerMixin):
                 merge_role = []
                 classify_summary = unknown_summary.get(card_summary['classify'], {})
                 for role, summary in classify_summary.items():
-                    if role in card_summary['role_set']:
+                    if one_card or role in card_summary['role_set']:
                         merge_role.append(role)
                         card_summary['sheet'].extend(summary['sheet'])
                         card_summary['code'].extend(summary['code'])
@@ -503,6 +508,7 @@ class Command(BaseCommand, LoggerMixin):
                 bs_summary = {}
                 license_summary = {}
                 unknown_summary = {}
+                skip_img = []
                 interest_keyword = Keywords.objects.filter(
                     type=KeywordsType.INTEREST.value, on_off=True).values_list('keyword', flat=True)
                 salary_keyword = Keywords.objects.filter(
@@ -515,27 +521,29 @@ class Command(BaseCommand, LoggerMixin):
                 # wb = Workbook()
 
                 # 4.1 获取OCR结果
-                loop = asyncio.get_event_loop()
-                tasks = [self.img_2_ocr_2_wb(wb, img_path, bs_summary, unknown_summary, license_summary)
-                         for img_path in pdf_handler.img_path_list]
-                loop.run_until_complete(asyncio.wait(tasks))
+                # loop = asyncio.get_event_loop()
+                # tasks = [self.img_2_ocr_2_wb(wb, img_path, bs_summary, unknown_summary, license_summary)
+                #          for img_path in pdf_handler.img_path_list]
+                # loop.run_until_complete(asyncio.wait(tasks))
                 # loop.close()
 
-                # for img_path in pdf_handler.img_path_list:
-                #     self.img_2_ocr_2_wb(wb, img_path, bs_summary, unknown_summary, license_summary)
+                for img_path in pdf_handler.img_path_list:
+                    self.img_2_ocr_2_wb(wb, img_path, bs_summary, unknown_summary, license_summary, skip_img)
 
-                self.cronjob_log.info('{0} [bs_summary={1}] [unknown_summary={2}] [license_summary={3}]'.format(
-                    self.log_base, bs_summary, unknown_summary, license_summary))
+                self.cronjob_log.info('{0} [business_type={1}] [doc_id={2}] [bs_summary={3}] [unknown_summary={4}] '
+                                      '[license_summary={5}]'.format(self.log_base, business_type, doc.id, bs_summary,
+                                                                     unknown_summary, license_summary))
 
                 merged_bs_summary = self.rebuild_bs_summary(bs_summary, unknown_summary)
 
-                self.cronjob_log.info('{0} [merged_bs_summary={1}] [unknown_summary={2}]'.format(
-                    self.log_base, merged_bs_summary, unknown_summary))
+                self.cronjob_log.info('{0} [business_type={1}] [doc_id={2}] [merged_bs_summary={3}] '
+                                      '[unknown_summary={4}]'.format(self.log_base, business_type, doc.id,
+                                                                     merged_bs_summary, unknown_summary))
                 del unknown_summary
 
                 # 4.2 重构Excel文件
                 wb.save(src_excel_path)
-                wb.rebuild(merged_bs_summary, license_summary)
+                wb.rebuild(merged_bs_summary, license_summary, skip_img)
                 wb.save(excel_path)
             except Exception as e:
                 doc.status = DocStatus.PROCESS_FAILED.value
diff --git a/src/apps/doc/ocr/wb.py b/src/apps/doc/ocr/wb.py
index 7ab073b..a949579 100644
--- a/src/apps/doc/ocr/wb.py
+++ b/src/apps/doc/ocr/wb.py
@@ -141,32 +141,22 @@ class BSWorkbook(Workbook):
                 # month_info process
                 month_info = month_mapping.setdefault('xxxx-xx', [])
                 month_info.append((ws.title, min_row, ws.max_row, 0))
-            elif len(month_list) == 1:
-                # reverse_trend_list process
-                reverse_trend = self.get_reverse_trend(dti.day, idx_list)
-                reverse_trend_list.append(reverse_trend)
-                # month_info process
-                month_info = month_mapping.setdefault(month_list[0], [])
-                day_mean = np.mean(dti.day.dropna())
-                if len(month_info) == 0:
-                    month_info.append((ws.title, min_row, ws.max_row, day_mean))
-                else:
-                    for i, item in enumerate(month_info):
-                        if day_mean <= item[-1]:
-                            month_info.insert(i, (ws.title, min_row, ws.max_row, day_mean))
-                            break
-                    else:
-                        month_info.append((ws.title, min_row, ws.max_row, day_mean))
             else:
                 # reverse_trend_list process
                 reverse_trend = self.get_reverse_trend(dti.day, idx_list)
                 reverse_trend_list.append(reverse_trend)
                 # month_info process
-                for i, item in enumerate(month_list[:-1]):
-                    month_mapping.setdefault(item, []).append(
-                        (ws.title, idx_list[i] + min_row, idx_list[i + 1] + min_row - 1, self.MAX_MEAN))
-                month_mapping.setdefault(month_list[-1], []).insert(
-                    0, (ws.title, idx_list[-1] + min_row, ws.max_row, 0))
+                day_idx = dti.day
+                idx_list_max_idx = len(idx_list) - 1
+                for i, item in enumerate(month_list):
+                    if i == idx_list_max_idx:
+                        day_mean = np.mean(day_idx[idx_list[i]:].dropna())
+                        month_mapping.setdefault(item, []).append(
+                            (ws.title, idx_list[i] + min_row, ws.max_row, day_mean))
+                    else:
+                        day_mean = np.mean(day_idx[idx_list[i]: idx_list[i + 1]].dropna())
+                        month_mapping.setdefault(item, []).append(
+                            (ws.title, idx_list[i] + min_row, idx_list[i + 1] + min_row - 1, day_mean))
 
     def build_metadata_rows(self, confidence, code, print_time, start_date, end_date):
         if start_date is None or end_date is None:
@@ -259,7 +249,7 @@ class BSWorkbook(Workbook):
                 except Exception as e:
                     continue
                 else:
-                    over_cell.number_format = numbers.FORMAT_NUMBER_COMMA_SEPARATED1
+                    over_cell.number_format = numbers.FORMAT_GENERAL
 
                 # 3.4.金额转数值
                 try:
@@ -281,7 +271,7 @@ class BSWorkbook(Workbook):
                 else:
                     if rows[consts.BORROW_IDX].value in consts.BORROW_OUTLAY_SET:
                         amount_cell.value = -amount_cell.value
-                    amount_cell.number_format = numbers.FORMAT_NUMBER_COMMA_SEPARATED1
+                    amount_cell.number_format = numbers.FORMAT_GENERAL
                     same_amount_mapping = amount_mapping.get(date_cell.value, {})
                     fill_rows = same_amount_mapping.get(-amount_cell.value)
                     if fill_rows:
@@ -357,11 +347,11 @@ class BSWorkbook(Workbook):
                                        end_date)
 
             # 3.创建月份表、提取/高亮关键行
-            is_reverse = False
-            if sum(reverse_trend_list) > 0:  # 倒序处理
-                is_reverse = True
-                for month_list in month_mapping.values():
-                    month_list.sort(key=lambda x: x[-1], reverse=True)
+            # 倒序处理
+            is_reverse = True if sum(reverse_trend_list) > 0 else False
+            for month_list in month_mapping.values():
+                month_list.sort(key=lambda x: x[-1], reverse=is_reverse)
+
             self.build_month_sheet(card, month_mapping, ms, is_reverse)
 
             # 4.删除原表
@@ -379,6 +369,14 @@ class BSWorkbook(Workbook):
                     ws.append(bl_field)
                 ws.append((None, ))
 
-    def rebuild(self, bs_summary, license_summary):
+    def skip_img_sheet(self, skip_img):
+        if skip_img:
+            ws = self.create_sheet(consts.SKIP_IMG_SHEET_NAME)
+            ws.append(consts.SKIP_IMG_SHEET_HEADER)
+            for img_tuple in skip_img:
+                ws.append(img_tuple)
+
+    def rebuild(self, bs_summary, license_summary, skip_img):
         self.bs_rebuild(bs_summary)
         self.license_rebuild(license_summary)
+        self.skip_img_sheet(skip_img)
--
libgit2 0.24.0