# wb.py 16.4 KB
import locale
import numpy as np
from pandas._libs import tslib
from pandas._libs.tslibs.nattype import NaTType
from pandas.core.indexes.datetimes import DatetimeIndex
from openpyxl import Workbook
from openpyxl.styles import Border, Side, PatternFill, numbers
from openpyxl.utils import get_column_letter
from apps.doc import consts


class BSWorkbook(Workbook):

    def __init__(self, interest_keyword, salary_keyword, loan_keyword, *args, **kwargs):
        """Create the workbook and set up keywords, header rows and cell fills.

        Args:
            interest_keyword: Collection tested with ``in`` against a row's
                summary value; matching rows are copied to the metadata sheet.
            salary_keyword: Collection tested the same way; matching rows are
                gathered in a temporary sheet and appended to the metadata
                sheet afterwards.
            loan_keyword: Collection tested the same way; matching summary
                cells are highlighted with ``loan_fill``.
            *args: Forwarded to ``Workbook.__init__``.
            **kwargs: Forwarded to ``Workbook.__init__``.

        Note:
            The constructor switches the process-wide ``LC_NUMERIC`` locale to
            ``en_US.UTF-8``; ``locale.atof`` is used later to parse amounts.
        """
        super().__init__(*args, **kwargs)
        locale.setlocale(locale.LC_NUMERIC, 'en_US.UTF-8')

        # Caller-supplied keyword collections.
        self.loan_keyword = loan_keyword
        self.salary_keyword = salary_keyword
        self.interest_keyword = interest_keyword

        # Title and fixed rows of the metadata sheet.
        self.meta_sheet_title = '关键信息提取和展示'
        self.blank_row = (None,)
        self.code_header = ('页数', '电子回单验证码')
        self.date_header = ('打印时间', '起始日期', '终止日期', '流水区间结果')
        self.keyword_header = ('关键词', '记账日期', '金额')

        # The two possible outcomes written by the balance-proof formula.
        self.proof_res = ('对', '错')

        # Highlight styles: loan keyword cells and offsetting same-day amounts.
        self.amount_fill = PatternFill("solid", fgColor="00FFFF00")
        self.loan_fill = PatternFill("solid", fgColor="00FFCC00")

        # Sort key given to the non-final months of a multi-month sheet.
        self.MAX_MEAN = 31

    @staticmethod
    def sheet_prune(ws, classify):
        """Rearrange a raw sheet into the fixed column layout, in place.

        ``consts.FIXED_COL_AMOUNT`` empty columns are inserted on the left,
        recognised source columns are moved into them, and every column to the
        right of that fixed area is deleted afterwards.

        Args:
            ws: Worksheet to rearrange.
            classify: Index into ``consts.CLASSIFY_LIST`` selecting the
                fallback column positions for targets not found by header.

        Returns:
            ``2`` when at least one column was placed through its first-row
            header value, otherwise ``1``.
        """
        fixed = consts.FIXED_COL_AMOUNT
        ws.insert_cols(1, amount=fixed)
        relocated = set()  # source columns that were moved by header match
        filled = set()     # target columns that already received data

        def relocate(src_col, dst_col):
            # Move one whole column (row 1 through the current last row).
            col_letter = get_column_letter(src_col)
            ws.move_range("{0}1:{0}{1}".format(col_letter, ws.max_row), cols=dst_col - src_col)

        # Pass 1: place columns according to the text in the first row.
        for src_col in range(fixed + 1, ws.max_column + 1):
            title = ws.cell(1, src_col).value
            dst_col = consts.HEADERS_MAPPING.get(title)
            if dst_col is None or dst_col in filled:
                # Not a (still free) mapped header: try the borrow / income /
                # outlay header sets, which are not checked against `filled`.
                if title in consts.BORROW_HEADERS_SET:
                    dst_col = consts.BORROW_HEADER_COL
                elif title in consts.INCOME_HEADERS_SET:
                    dst_col = consts.INCOME_HEADER_COL
                elif title in consts.OUTLAY_HEADERS_SET:
                    dst_col = consts.OUTLAY_HEADER_COL
                else:
                    continue
            relocate(src_col, dst_col)
            relocated.add(src_col)
            filled.add(dst_col)

        # Pass 2: for targets still empty, fall back to the position listed
        # for this classification.
        for dst_col in range(1, fixed + 1):
            if dst_col in filled or dst_col == consts.RESULT_HEADER_COL:
                continue
            offset = consts.CLASSIFY_LIST[classify][1][dst_col - 1]
            if offset is None:
                continue
            src_col = offset + fixed
            if src_col in relocated:
                # NOTE(review): this ends the whole fallback pass instead of
                # skipping just this target - confirm that is intended.
                break
            relocate(src_col, dst_col)

        ws.delete_cols(fixed + 1, amount=ws.max_column)
        return 2 if relocated else 1

    @staticmethod
    def month_split(dti, date_list, date_statistics):
        month_list = []
        idx_list = []
        month_pre = None
        for idx, month_str in enumerate(dti.strftime('%Y-%m')):
            if isinstance(month_str, float):
                continue
            if month_str != month_pre:
                month_list.append(month_str)
                if month_pre is None:
                    if date_statistics:
                        date_list.append(dti[idx].date())
                    idx = 0
                idx_list.append(idx)
                month_pre = month_str
        if date_statistics:
            for idx in range(len(dti) - 1, -1, -1):
                if isinstance(dti[idx], NaTType):
                    continue
                date_list.append(dti[idx].date())
                break
        return month_list, idx_list

    @staticmethod
    def get_reverse_trend(day_idx, idx_list):
        reverse_trend = 0
        pre_day = None
        for idx, day in enumerate(day_idx):
            if np.isnan(day):
                continue
            if idx in idx_list or pre_day is None:
                pre_day = day
                continue
            if day < pre_day:
                reverse_trend += 1
                pre_day = day
            elif day > pre_day:
                reverse_trend -= 1
                pre_day = day
        if reverse_trend > 0:
            reverse_trend = 1
        elif reverse_trend < 0:
            reverse_trend = -1
        return reverse_trend

    def sheet_split(self, ws, month_mapping, reverse_trend_list, min_row, date_list, date_statistics):
        """Record which row ranges of ``ws`` belong to which month.

        The first column (from ``min_row`` down) is parsed as dates and the
        rows are grouped by the ``YYYY-MM`` label found by ``month_split``.

        Args:
            ws: Worksheet whose first column holds the dates.
            month_mapping: Dict updated in place. Maps a month label to a list
                of ``(sheet title, first row, last row, sort key)`` tuples.
                A sheet without any parseable date is filed under
                ``'xxxx-xx'`` with sort key 0.
            reverse_trend_list: List that receives this sheet's
                ``get_reverse_trend`` result (nothing is appended for a sheet
                without parseable dates).
            min_row: First data row of ``ws``.
            date_list: Forwarded to ``month_split``.
            date_statistics: Forwarded to ``month_split``.
        """
        for date_tuple_src in ws.iter_cols(min_col=1, max_col=1, min_row=min_row, values_only=True):
            # Text values are cut to their first 10 characters before parsing.
            date_tuple = [date[:10] if isinstance(date, str) else date for date in date_tuple_src]
            # np.asarray replaces np.array(..., copy=False): under NumPy >= 2.0
            # copy=False raises ValueError when a copy is unavoidable, which is
            # always the case for a Python list. Older NumPy behaves the same.
            # NOTE(review): tslib.array_to_datetime is a private pandas API and
            # its keyword arguments may differ between pandas versions.
            dt_array, _ = tslib.array_to_datetime(
                np.asarray(date_tuple, dtype=np.object_),
                errors="coerce",
                utc=False,
                dayfirst=False,
                yearfirst=False,
                require_iso8601=True,
            )
            dti = DatetimeIndex(dt_array, tz=None, name=None)

            month_list, idx_list = self.month_split(dti, date_list, date_statistics)

            if not month_list:
                # No usable date at all: file the whole sheet under a placeholder.
                month_mapping.setdefault('xxxx-xx', []).append(
                    (ws.title, min_row, ws.max_row, 0))
                continue

            reverse_trend_list.append(self.get_reverse_trend(dti.day, idx_list))

            if len(month_list) == 1:
                # Whole sheet is one month: keep that month's parts ordered by
                # ascending mean day-of-month.
                month_info = month_mapping.setdefault(month_list[0], [])
                day_mean = np.mean(dti.day.dropna())
                entry = (ws.title, min_row, ws.max_row, day_mean)
                for i, item in enumerate(month_info):
                    if day_mean <= item[-1]:
                        month_info.insert(i, entry)
                        break
                else:
                    month_info.append(entry)
            else:
                # Several months: every month but the last ends one row before
                # the next one starts and goes to the end of its list
                # (MAX_MEAN); the last month runs to the sheet's last row and
                # goes to the front of its list (key 0).
                for i, item in enumerate(month_list[:-1]):
                    month_mapping.setdefault(item, []).append(
                        (ws.title, idx_list[i] + min_row, idx_list[i + 1] + min_row - 1, self.MAX_MEAN))
                month_mapping.setdefault(month_list[-1], []).insert(
                    0, (ws.title, idx_list[-1] + min_row, ws.max_row, 0))

    def build_metadata_rows(self, confidence, code, print_time, start_date, end_date):
        if start_date is None or end_date is None:
            timedelta = None
        else:
            timedelta = (end_date - start_date).days
        metadata_rows = [
            ('流水识别置信度', confidence),
            self.blank_row,
            self.code_header,
        ]
        metadata_rows.extend(code)
        metadata_rows.extend(
            [self.blank_row,
             self.date_header,
             (print_time, start_date, end_date, timedelta),
             self.blank_row,
             self.keyword_header]
        )
        return metadata_rows

    def create_meta_sheet(self, card):
        """Return the metadata sheet for ``card``.

        If the workbook's first sheet is still titled ``'Sheet'`` it is
        renamed and reused; otherwise a new sheet is created. The title is
        ``meta_sheet_title`` followed by the last six characters of ``card``
        in parentheses.
        """
        first_sheet = self.worksheets[0]
        reuse_first = first_sheet.title == 'Sheet'
        title = '{0}({1})'.format(self.meta_sheet_title, card[-6:])
        if not reuse_first:
            return self.create_sheet(title)
        first_sheet.title = title
        return first_sheet

    def build_meta_sheet(self, card, confidence, code, print_time, start_date, end_date):
        """Create the metadata sheet for ``card`` and fill in its opening rows.

        The rows come from ``build_metadata_rows`` and the sheet from
        ``create_meta_sheet``; the filled sheet is returned.
        """
        # Rows are built before the sheet is created, as in the call order
        # callers have always seen.
        header_rows = self.build_metadata_rows(confidence, code, print_time, start_date, end_date)
        meta_sheet = self.create_meta_sheet(card)
        for header_row in header_rows:
            meta_sheet.append(header_row)
        return meta_sheet

    @staticmethod
    def amount_format(amount_str):
        if not isinstance(amount_str, str) or amount_str == '':
            return amount_str
        # 替换
        res_str = amount_str.translate(consts.TRANS)
        # 删除多余的-
        res_str = res_str[0] + res_str[1:].replace('-', '')
        # TODO 逗号与句号处理
        return res_str

    def build_month_sheet(self, card, month_mapping, ms, is_reverse):
        """Create one sheet per month and extract or highlight its key rows.

        For every month label (in sorted order) the listed row ranges are
        copied into a new sheet titled ``<month>(<last six chars of card>)``.
        Each copied row is then processed: keyword rows are collected,
        balance and amount cells are converted to numbers, a balance-proof
        formula is written, and amounts that offset another amount on the
        same date are highlighted.

        Args:
            card: Card identifier; its last six characters go into the titles.
            month_mapping: Maps a month label to a list of tuples whose first
                three items are ``(source sheet title, first row, last row)``.
            ms: Metadata sheet receiving the keyword rows.
            is_reverse: Selects which neighbouring row the proof formula
                uses (see step 3.5).
        """
        tmp_ws = self.create_sheet('tmp_ws')
        for month in sorted(month_mapping.keys()):
            # 3.1 Copy the data.
            parts = month_mapping.get(month)
            new_ws = self.create_sheet('{0}({1})'.format(month, card[-6:]))
            new_ws.append(consts.FIXED_HEADERS)
            for part in parts:
                # Item access replaces the deprecated get_sheet_by_name().
                ws = self[part[0]]
                for row_value in ws.iter_rows(min_row=part[1], max_row=part[2], values_only=True):
                    new_ws.append(row_value)

            # 3.2 Extract information and highlight.
            amount_mapping = {}      # date -> {amount -> [row numbers]}
            amount_fill_row = set()  # rows whose amount cell gets highlighted
            for rows in new_ws.iter_rows(min_row=2):
                summary_cell = rows[consts.SUMMARY_IDX]
                date_cell = rows[consts.DATE_IDX]
                amount_cell = rows[consts.AMOUNT_IDX]
                row = summary_cell.row
                if summary_cell.value in self.interest_keyword:
                    # First keyword group: straight to the metadata sheet.
                    ms.append((summary_cell.value, date_cell.value, amount_cell.value))
                elif summary_cell.value in self.salary_keyword:
                    # Second keyword group: parked in the temporary sheet.
                    tmp_ws.append((summary_cell.value, date_cell.value, amount_cell.value))
                elif summary_cell.value in self.loan_keyword:
                    # Loan keyword: highlight the summary cell.
                    summary_cell.fill = self.loan_fill

                # 3.3 Convert the balance to a number. Any failure skips the
                # remaining steps for this row (deliberate best effort).
                over_cell = rows[consts.OVER_IDX]
                try:
                    over_cell.value = locale.atof(self.amount_format(over_cell.value))
                except Exception:
                    continue
                over_cell.number_format = numbers.FORMAT_NUMBER_COMMA_SEPARATED1

                # 3.4 Convert the amount to a number. Sources are tried in
                # order: the amount cell itself, the income cell (stored as a
                # positive value), then the outlay cell (stored as a negative
                # value). If all fail the row is skipped from here on.
                try:
                    try:
                        amount_cell.value = locale.atof(self.amount_format(amount_cell.value))
                    except Exception:
                        try:
                            amount_cell.value = locale.atof(self.amount_format(rows[consts.INCOME_IDX].value))
                            if amount_cell.value == 0:
                                # Re-raise the active exception so a zero
                                # income falls through to the outlay cell.
                                raise
                            elif amount_cell.value < 0:
                                amount_cell.value = -amount_cell.value
                        except Exception:
                            amount_cell.value = locale.atof(self.amount_format(rows[consts.OUTLAY_IDX].value))
                            if amount_cell.value > 0:
                                amount_cell.value = -amount_cell.value
                except Exception:
                    continue

                # Flip the sign when the borrow marker is in BORROW_OUTLAY_SET.
                if rows[consts.BORROW_IDX].value in consts.BORROW_OUTLAY_SET:
                    amount_cell.value = -amount_cell.value
                amount_cell.number_format = numbers.FORMAT_NUMBER_COMMA_SEPARATED1

                # Remember rows whose amount cancels an earlier amount of the
                # same date, together with those earlier rows.
                same_amount_mapping = amount_mapping.get(date_cell.value, {})
                fill_rows = same_amount_mapping.get(-amount_cell.value)
                if fill_rows:
                    amount_fill_row.add(row)
                    amount_fill_row.update(fill_rows)
                amount_mapping.setdefault(date_cell.value, {}).setdefault(
                    amount_cell.value, []).append(row)

                # 3.5 Balance proof formula (not written for the first data row).
                # Reverse: D[row-1] = D[row] + C[row-1]; else D[row] = D[row-1] + C[row].
                if row > 2:
                    first, second = (row - 1, row) if is_reverse else (row, row - 1)
                    rows[consts.RESULT_IDX].value = '=IF(D{0}=SUM(D{1},C{0}), "{2}", "{3}")'.format(
                        first, second, *self.proof_res)

            # Drop the helper columns from BORROW_HEADER_COL onwards.
            new_ws.delete_cols(consts.BORROW_HEADER_COL, amount=new_ws.max_column)

            # 3.6 Highlight offsetting in/out amounts of the same day.
            for row in amount_fill_row:
                new_ws[row][consts.AMOUNT_IDX].fill = self.amount_fill

        # Append the second keyword group to the metadata sheet, then discard
        # the temporary sheet.
        ms.append(self.blank_row)
        ms.append(self.keyword_header)
        for row in tmp_ws.iter_rows(values_only=True):
            ms.append(row)
        self.remove(tmp_ws)

    def bs_rebuild(self, bs_summary):
        # bs_summary = {
        #     '卡号': {
        #         'classify': 0,
        #         'confidence': 0.9,
        #         'role': '柳雪',
        #         'code': [('page', 'code')],
        #         'print_time': 'datetime',
        #         'start_date': 'datetime',
        #         'end_date': 'datetime',
        #         'sheet': ['sheet_name']
        #     }
        # }
        for card, summary in bs_summary.items():
            # 1.原表修剪、排列、按照月份分割
            start_date = summary.get('start_date')
            end_date = summary.get('end_date')
            date_statistics = False
            if start_date is None or end_date is None:
                date_statistics = True
            date_list = []
            month_mapping = {}
            reverse_trend_list = []
            for sheet in summary.get('sheet', []):
                ws = self.get_sheet_by_name(sheet)
                # 1.1.删除多余列、排列
                min_row = self.sheet_prune(ws, summary.get('classify', 0))
                # 1.2.按月份分割
                self.sheet_split(ws, month_mapping, reverse_trend_list, min_row, date_list, date_statistics)

            if date_statistics is True and len(date_list) > 1:
                start_date = min(date_list) if start_date is None else start_date
                end_date = max(date_list) if end_date is None else end_date

            # 2.元信息提取表
            ms = self.build_meta_sheet(card,
                                       summary.get('confidence', 1),
                                       summary.get('code'),
                                       summary.get('print_time'),
                                       start_date,
                                       end_date)

            # 3.创建月份表、提取/高亮关键行
            is_reverse = False
            if sum(reverse_trend_list) > 0:  # 倒序处理
                is_reverse = True
                for month_list in month_mapping.values():
                    month_list.sort(key=lambda x: x[-1], reverse=True)
            self.build_month_sheet(card, month_mapping, ms, is_reverse)

            # 4.删除原表
            for sheet in summary.get('sheet'):
                self.remove(self.get_sheet_by_name(sheet))

    def license_rebuild(self, license_summary):
        """Create one sheet per entry of ``consts.LICENSE_ORDER`` that has data.

        Args:
            license_summary: Mapping looked up with each entry's first item.
                A value is an iterable of records, each an iterable of rows;
                every record's rows are appended, followed by one blank row.
                Entries whose lookup gives ``None`` produce no sheet.
        """
        for key, (_, sheet_name) in consts.LICENSE_ORDER:
            records = license_summary.get(key)
            if records is not None:
                sheet = self.create_sheet(sheet_name)
                for record in records:
                    for field_row in record:
                        sheet.append(field_row)
                    # Blank separator row after each record.
                    sheet.append((None, ))

    def rebuild(self, bs_summary, license_summary):
        """Run ``bs_rebuild`` on ``bs_summary``, then ``license_rebuild`` on ``license_summary``."""
        self.bs_rebuild(bs_summary)
        self.license_rebuild(license_summary)