# doc_process.py
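"""
Django management command for the document-processing pipeline: pull a task from
the Redis queue, locate the corresponding PDF on disk (downloading it from EDMS
is still a TODO), extract or render its pages as images, send each image to the
OCR service, and write the recognized tables into an Excel workbook.

Assuming this file lives under an app's management/commands/ directory, it would
typically be run as ``python manage.py doc_process``.
"""
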
import os
import time
import fitz
import xlwt
import signal
import base64
import asyncio
import aiohttp
from PIL import Image
from io import BytesIO

from django.core.management import BaseCommand
from common.mixins import LoggerMixin
from common.redis_cache import redis_handler as rh
from common.tools.file_tools import write_zip_file
from apps.doc.models import DocStatus, HILDoc, AFCDoc
from apps.doc import consts
from settings import conf


class Command(BaseCommand, LoggerMixin):

    def __init__(self):
        super().__init__()
        self.log_base = '[doc process]'
        # switch that keeps the processing loop running
        self.switch = True
        # data directory
        self.data_dir = conf.DATA_DIR
        # render PDF pages to images
        self.zoom_x = 2.0
        self.zoom_y = 2.0
        self.trans = fitz.Matrix(self.zoom_x, self.zoom_y).preRotate(0)  # zoom factor 2 in each dimension
        # OCR settings
        self.ocr_url = conf.OCR_URL
        self.ocr_header = {
            'X-Auth-Token': conf.OCR_TOKEN,
            'Content-Type': 'application/json'
        }
        # graceful-exit signal: 15 (SIGTERM)
        signal.signal(signal.SIGTERM, self.signal_handler)

    def signal_handler(self, sig, frame):
        self.switch = False  # stop processing files

    def get_doc_info(self):
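        """
        Pop a task ("{business_type}_{doc_id}") from the Redis queue, load the
        matching document record that is still in INIT status and mark it as
        PROCESSING. Returns (doc_info, doc_class, doc_id, business_type), or all
        None when the queue is empty or the document was already handled.
        """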
        task_str, is_priority = rh.dequeue()
        if task_str is None:
            self.cronjob_log.info('{0} [get_doc_info] [queue empty]'.format(self.log_base))
            return None, None, None, None

        business_type, doc_id_str = task_str.split('_')
        doc_id = int(doc_id_str)
        doc_class = HILDoc if business_type == consts.HIL_PREFIX else AFCDoc
        doc_info = doc_class.objects.filter(id=doc_id, status=DocStatus.INIT.value).values(
            'id', 'metadata_version_id', 'document_name').first()
        if doc_info is None:
            self.cronjob_log.warn('{0} [get_doc_info] [doc completed] [task_str={1}] [is_priority={2}]'.format(
                self.log_base, task_str, is_priority))
            return None, None, None, None
        doc_class.objects.filter(id=doc_id).update(status=DocStatus.PROCESSING.value)
        self.cronjob_log.info('{0} [get_doc_info] [task_str={1}] [is_priority={2}] [doc_info={3}]'.format(
            self.log_base, task_str, is_priority, doc_info))
        return doc_info, doc_class, doc_id, business_type

    def pdf_download(self, doc_id, doc_info, business_type):
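        """
        Build the local working paths for this document (data directory, target
        Excel file and PDF file). The actual download from EDMS is not
        implemented yet (see the TODO below); the PDF is expected to already be
        on disk at the returned path.
        """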
        if doc_info is None:
            return None, None, None
        # TODO download the PDF from EDMS
        # pdf_path = '/Users/clay/Desktop/biz/biz_logic/data/2/横版-表格-工商银行CH-B008802400.pdf'
        # doc_data_path = os.path.dirname(pdf_path)

        doc_data_path = os.path.join(self.data_dir, business_type, str(doc_id))
        pdf_path = os.path.join(doc_data_path, '{0}.pdf'.format(doc_id))
        excel_path = os.path.join(doc_data_path, '{0}.xls'.format(doc_id))
        self.cronjob_log.info('{0} [pdf download success] [business_type={1}] [doc_info={2}] [pdf_path={3}]'.format(
            self.log_base, business_type, doc_info, pdf_path))
        return doc_data_path, excel_path, pdf_path

    @staticmethod
    def append_sheet(wb, sheets_list, img_name):
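        """
        Add one worksheet per recognized table to the workbook, named
        "<img_name>_<index>", writing every cell as a merged range with its
        recognized text.
        """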
        for i, sheet in enumerate(sheets_list):
            ws = wb.add_sheet('{0}_{1}'.format(img_name, i))
            cells = sheet.get('cells')
            for cell in cells:
                c1 = cell.get('start_column')
                c2 = cell.get('end_column')
                r1 = cell.get('start_row')
                r2 = cell.get('end_row')
                label = cell.get('words')
                ws.write_merge(r1, r2, c1, c2, label=label)

    @staticmethod
    def get_ocr_json(img_path):
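        """Read an image file and wrap its base64-encoded content in the JSON payload expected by the OCR service."""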
        with open(img_path, "rb") as f:
            base64_data = base64.b64encode(f.read())
        return {'imgBase64': base64_data.decode('utf-8')}

    async def fetch_ocr_result(self, img_path):
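        """POST the base64-encoded image to the OCR service and return the decoded JSON response."""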
        async with aiohttp.ClientSession(
                headers=self.ocr_header, connector=aiohttp.TCPConnector(ssl=False)
        ) as session:
            json_data = self.get_ocr_json(img_path)
            async with session.post(self.ocr_url, json=json_data) as response:
                return await response.json()

    async def img_ocr_excel(self, wb, img_path):
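        """Fetch the OCR result for one image and append the recognized tables to the workbook."""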
        res = await self.fetch_ocr_result(img_path)
        self.cronjob_log.info('{0} [fetch ocr result success] [img={1}] [res={2}]'.format(self.log_base, img_path, res))
        sheets_list = res.get('result').get('res')
        img_name = os.path.basename(img_path)
        self.append_sheet(wb, sheets_list, img_name)

    @staticmethod
    def getimage(pix):
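        """Convert CMYK pixmaps (4 colour components) to RGB; other pixmaps are returned unchanged."""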
        if pix.colorspace.n != 4:
            return pix
        tpix = fitz.Pixmap(fitz.csRGB, pix)
        return tpix

    def recoverpix(self, doc, item):
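        """
        Rebuild a saveable image from a page image-list entry: extract the image
        by its xref and, if it has an /SMask, apply the SMask samples as the
        alpha channel; GRAY/CMYK results are converted to RGB.
        """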
        x = item[0]  # xref of PDF image
        s = item[1]  # xref of its /SMask
        is_rgb = item[5] == 'DeviceRGB'

        # RGB
        if is_rgb:
            if s == 0:
                return doc.extractImage(x)
            # we need to reconstruct the alpha channel with the smask
            pix1 = fitz.Pixmap(doc, x)
            pix2 = fitz.Pixmap(doc, s)  # create pixmap of the /SMask entry

            # sanity check
            if not (pix1.irect == pix2.irect and pix1.alpha == pix2.alpha == 0 and pix2.n == 1):
                pix2 = None
                return self.getimage(pix1)

            pix = fitz.Pixmap(pix1)  # copy of pix1, alpha channel added
            pix.setAlpha(pix2.samples)  # treat pix2.samples as alpha value
            pix1 = pix2 = None  # free temp pixmaps
            return self.getimage(pix)

        # GRAY/CMYK
        pix1 = fitz.Pixmap(doc, x)
        pix = fitz.Pixmap(pix1)  # copy of pix1, alpha channel added

        if s != 0:
            pix2 = fitz.Pixmap(doc, s)  # create pixmap of the /SMask entry

            # sanity check
            if not (pix1.irect == pix2.irect and pix1.alpha == pix2.alpha == 0 and pix2.n == 1):
                pix2 = None
                return self.getimage(pix1)

            pix.setAlpha(pix2.samples)  # treat pix2.samples as alpha value

        pix1 = pix2 = None  # free temp pixmaps

        pix = fitz.Pixmap(fitz.csRGB, pix)  # GRAY/CMYK to RGB
        return self.getimage(pix)

    @staticmethod
    def get_img_data(pix):
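        """Return (extension, image bytes) for either a raw-image dict or a fitz.Pixmap produced by recoverpix."""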
        if isinstance(pix, dict):  # we got a raw image
            ext = pix["ext"]
            img_data = pix["image"]
        else:  # we got a pixmap
            ext = 'png'
            img_data = pix.getPNGData()
        return ext, img_data

    @staticmethod
    def split_il(il):
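        """
        Split a page image list (sorted by xref) into runs of consecutive entries
        that share the same width, so that horizontal strips of one picture can
        later be stitched back together vertically. A width change starts a new
        run at the current entry; a height change closes the current run after
        the current entry.
        """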
        img_il_list = []
        start = 0
        length = len(il)
        for i in range(length):
            if i == start:
                if i == length - 1:
                    img_il_list.append(il[start: length])
                continue
            elif i == length - 1:
                img_il_list.append(il[start: length])
                continue
            if il[i][2] != il[i - 1][2]:
                img_il_list.append(il[start: i])
                start = i
            elif il[i][3] != il[i - 1][3]:
                img_il_list.append(il[start: i + 1])
                start = i + 1
        return img_il_list

    def handle(self, *args, **kwargs):  # TODO retry API calls on failure
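        """
        Main worker loop: keep pulling tasks until SIGTERM flips the switch.
        For each document, extract page images from the PDF (falling back to
        rendering whole pages when a page contains many small images), zip them,
        OCR every image into an Excel workbook, and update the document status
        to COMPLETE or PROCESS_FAILED. Sleeps with a linear back-off (5s up to
        60s) while the queue is empty.
        """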
        sleep_second = 5
        max_sleep_second = 60
        while self.switch:
            # fetch the document info from the queue
            doc_info, doc_class, doc_id, business_type = self.get_doc_info()
            # fetch the PDF file from EDMS
            doc_data_path, excel_path, pdf_path = self.pdf_download(doc_id, doc_info, business_type)
            # back off when the queue is empty
            if pdf_path is None:
                time.sleep(sleep_second)
                sleep_second = min(max_sleep_second, sleep_second+5)
                continue
            sleep_second = 5
            try:
                # extract images from the PDF
                img_save_path = os.path.join(doc_data_path, 'img')
                os.makedirs(img_save_path, exist_ok=True)
                img_path_list = []
                with fitz.Document(pdf_path) as pdf:
                    self.cronjob_log.info('{0} [pdf_path={1}] [metadata={2}]'.format(
                        self.log_base, pdf_path, pdf.metadata))
                    # xref_list = []  # TODO de-duplicate images in special PDFs such as electronic invoices
                    for pno in range(pdf.pageCount):
                        il = pdf.getPageImageList(pno)
                        il.sort(key=lambda x: x[0])
                        img_il_list = self.split_il(il)
                        del il

                        if len(img_il_list) > 3:  # too many scattered small images on one page: render the whole page instead
                            page = pdf.loadPage(pno)
                            pm = page.getPixmap(matrix=self.trans, alpha=False)
                            save_path = os.path.join(img_save_path, 'page_{0}_img_0.png'.format(page.number))
                            pm.writePNG(save_path)
                            img_path_list.append(save_path)
                            self.cronjob_log.info('{0} [page to img success] [doc_id={1}] [pdf_path={2}] '
                                                  '[page={3}]'.format(self.log_base, doc_id, pdf_path, page.number))
                        else:  # extract the embedded images
                            for img_index, img_il in enumerate(img_il_list):
                                if len(img_il) == 1:  # only one image: save it directly
                                    pix = self.recoverpix(pdf, img_il[0])
                                    ext, img_data = self.get_img_data(pix)
                                    save_path = os.path.join(img_save_path, 'page_{0}_img_{1}.{2}'.format(
                                        pno, img_index, ext))
                                    with open(save_path, "wb") as f:
                                        f.write(img_data)
                                    img_path_list.append(save_path)
                                    self.cronjob_log.info(
                                        '{0} [extract img success] [doc_id={1}] [pdf_path={2}] [page={3}] '
                                        '[img_index={4}]'.format(self.log_base, doc_id, pdf_path, pno, img_index))
                                else:  # multiple images: stitch them vertically
                                    height_sum = 0
                                    im_list = []
                                    width = img_il[0][2]
                                    for img in img_il:
                                        # xref = img[0]
                                        # if xref in xref_list:
                                        #     continue
                                        height = img[3]
                                        pix = self.recoverpix(pdf, img)
                                        ext, img_data = self.get_img_data(pix)

                                        # xref_list.append(xref)

                                        im = Image.open(BytesIO(img_data))
                                        im_list.append((height, im, ext))
                                        height_sum += height

                                    save_path = os.path.join(img_save_path, 'page_{0}_img_{1}.{2}'.format(
                                        pno, img_index, im_list[0][2]))
                                    res = Image.new(im_list[0][1].mode, (width, height_sum))
                                    h_now = 0
                                    for h, m, _ in im_list:
                                        res.paste(m, box=(0, h_now))
                                        h_now += h
                                    res.save(save_path)
                                    img_path_list.append(save_path)
                                    self.cronjob_log.info(
                                        '{0} [extract img success] [doc_id={1}] [pdf_path={2}] [page={3}] '
                                        '[img_index={4}]'.format(self.log_base, doc_id, pdf_path, pno, img_index))
                    self.cronjob_log.info('{0} [pdf to img success] [doc_id={1}]'.format(self.log_base, doc_id))

                write_zip_file(img_save_path, os.path.join(doc_data_path, '{0}_img.zip'.format(doc_id)))
                # classify each image (bank statement or not) and OCR each image into an Excel file
                wb = xlwt.Workbook()
                loop = asyncio.get_event_loop()
                tasks = [self.img_ocr_excel(wb, img_path) for img_path in img_path_list]
                loop.run_until_complete(asyncio.wait(tasks))
                # loop.close()
                wb.save(excel_path)  # TODO no sheet (res always [])
                # merge the Excel files
                # upload the result to EDMS
            except Exception as e:
                doc_class.objects.filter(id=doc_id).update(status=DocStatus.PROCESS_FAILED.value)
                self.cronjob_log.error('{0} [process failed] [doc_id={1}] [err={2}]'.format(self.log_base, doc_id, e))
            else:
                doc_class.objects.filter(id=doc_id).update(status=DocStatus.COMPLETE.value)
                self.cronjob_log.info('{0} [doc process complete] [doc_id={1}]'.format(self.log_base, doc_id))

        self.cronjob_log.info('{0} [stop safely]'.format(self.log_base))