import locale
import numpy as np
from pandas._libs import tslib
from pandas._libs.tslibs.nattype import NaTType
from pandas.core.indexes.datetimes import DatetimeIndex
from openpyxl import Workbook
from openpyxl.styles import Border, Side, PatternFill, numbers
from openpyxl.utils import get_column_letter
from apps.doc import consts


class BSWorkbook(Workbook):

    def __init__(self, interest_keyword, salary_keyword, loan_keyword, *args, **kwargs):
        """Create the workbook and set up the texts, keywords and fills used while rebuilding.

        Args:
            interest_keyword: Container of summary texts; rows whose summary is
                in it are copied to the meta sheet by ``build_month_sheet``.
            salary_keyword: Container of summary texts; matching rows are
                listed in a second keyword section of the meta sheet.
            loan_keyword: Container of summary texts; matching rows are
                highlighted by ``build_month_sheet``.
            *args: Forwarded unchanged to ``Workbook.__init__``.
            **kwargs: Forwarded unchanged to ``Workbook.__init__``.

        Note:
            This sets the process-wide ``LC_NUMERIC`` locale to ``en_US.UTF-8``;
            ``locale.atof`` (used when amounts are converted) then parses
            numbers with that locale's conventions. ``locale.setlocale`` raises
            ``locale.Error`` when the locale is not available.
        """
        super().__init__(*args, **kwargs)
        locale.setlocale(locale.LC_NUMERIC, 'en_US.UTF-8')

        # Keyword containers supplied by the caller.
        self.interest_keyword, self.salary_keyword, self.loan_keyword = (
            interest_keyword, salary_keyword, loan_keyword)

        # Fixed texts written to the meta sheet.
        self.meta_sheet_title = '关键信息提取和展示'
        self.code_header = ('页数', '电子回单验证码')
        self.date_header = ('打印时间', '起始日期', '终止日期', '流水区间结果')
        self.keyword_header = ('关键词', '记账日期', '金额')
        # Texts shown by the balance check formula (match, mismatch).
        self.proof_res = ('对', '错')
        # Single empty cell, appended to produce a blank separator row.
        self.blank_row = (None,)

        # Cell fills used for highlighting.
        self.loan_fill = PatternFill("solid", fgColor="00FFCC00")
        self.amount_fill = PatternFill("solid", fgColor="00FFFF00")

        self.MAX_MEAN = 31

    @staticmethod
    def header_collect(ws, sheet_header_info, header_info, max_column_list, classify):
        """Scan the first row of ``ws`` for known header texts and record where they are.

        Nothing is returned; the results are accumulated in the containers
        passed in by the caller:

        * ``sheet_header_info[ws.title]`` receives the raw first row under
          ``consts.HEADER_KEY``, one ``column key -> 0-based index`` entry per
          recognised header (the first occurrence of a key wins), the set of
          recognised indexes under ``consts.FIND_COL_KEY``, the number of
          recognised headers under ``consts.FIND_COUNT_KEY`` and the first
          data row under ``consts.MIN_ROW_KEY`` (1 when no header was
          recognised, otherwise 2).
        * ``header_info[column key][index]`` counts how often that column key
          was seen at that index.
        * ``max_column_list`` gets ``ws.max_column`` appended.

        Args:
            ws: Worksheet whose first row is inspected.
            sheet_header_info: Dict keyed by sheet title, updated in place.
            header_info: Dict keyed by column key, updated in place.
            max_column_list: List collecting the width of every sheet.
            classify: Statement type; ``consts.WECHART_CLASSIFY`` selects
                ``consts.WECHART_HEADERS_MAPPING``, anything else
                ``consts.HEADERS_MAPPING``.
        """
        per_sheet = sheet_header_info.setdefault(ws.title, {})
        if classify == consts.WECHART_CLASSIFY:
            header_mapping = consts.WECHART_HEADERS_MAPPING
        else:
            header_mapping = consts.HEADERS_MAPPING

        matched = 0
        # Only the first row can hold header texts.
        for first_row in ws.iter_rows(max_row=1, min_row=1, values_only=True):
            per_sheet.setdefault(consts.HEADER_KEY, first_row)
            for col_idx, header_text in enumerate(first_row):
                col_key = header_mapping.get(header_text)
                if col_key is None:
                    continue
                matched += 1
                per_sheet.setdefault(col_key, col_idx)
                per_sheet.setdefault(consts.FIND_COL_KEY, set()).add(col_idx)
                votes = header_info.setdefault(col_key, {})
                votes[col_idx] = votes.get(col_idx, 0) + 1

        per_sheet.setdefault(consts.FIND_COUNT_KEY, matched)
        # Without a recognised header the first row is treated as data.
        per_sheet.setdefault(consts.MIN_ROW_KEY, 2 if matched else 1)
        max_column_list.append(ws.max_column)

    @staticmethod
    def header_statistics(sheet_header_info, header_info, classify):
        """Choose one 0-based column index per logical column for the merged sheets.

        The sheet with the most recognised headers is the reference. When no
        sheet has a recognised header (or there are no sheets at all), the
        fixed layout of ``consts.CLASSIFY_LIST[classify][1]`` is used, converted
        from 1-based to 0-based, together with
        ``consts.CLASSIFY_HEADER_LIST[classify]`` as header row.

        Otherwise every key of ``consts.KEY_LIST`` is resolved in this order:

        1. the index recorded for the key in the reference sheet;
        2. the most frequently seen index for the key in ``header_info`` that
           is not already used by another column;
        3. the fixed layout column for ``classify`` if it is an ``int`` and
           still unused;
        4. ``None``.

        Args:
            sheet_header_info: Per-sheet data as filled by ``header_collect``.
            header_info: ``column key -> {index: count}`` as filled by
                ``header_collect``.
            classify: Index into ``consts.CLASSIFY_LIST`` and
                ``consts.CLASSIFY_HEADER_LIST``.

        Returns:
            dict: column key -> 0-based index or ``None``, plus the header row
            under ``consts.HEADER_KEY``.
        """
        statistics_header_info = {}
        # max() returns the first sheet with the highest count, exactly like
        # taking element 0 of a stable descending sort; default=None covers an
        # empty sheet_header_info instead of raising IndexError.
        best_sheet = max(sheet_header_info,
                         key=lambda name: sheet_header_info[name][consts.FIND_COUNT_KEY],
                         default=None)
        best_sheet_info = {} if best_sheet is None else sheet_header_info[best_sheet]

        if best_sheet_info.get(consts.FIND_COUNT_KEY, 0) == 0:
            for key, value in consts.CLASSIFY_MAP.items():
                col = consts.CLASSIFY_LIST[classify][1][value]
                statistics_header_info[key] = col - 1 if isinstance(col, int) else None
            statistics_header_info[consts.HEADER_KEY] = consts.CLASSIFY_HEADER_LIST[classify]
            return statistics_header_info

        # Work on a copy: the set stored in sheet_header_info is read again
        # per sheet by get_data_col_min_row and must keep describing only the
        # headers that were really found in that sheet.
        used_cols = set(best_sheet_info.get(consts.FIND_COL_KEY, ()))
        for key in consts.KEY_LIST:
            col = best_sheet_info.get(key)
            if col is None:
                votes = header_info.get(key, {})
                for idx in sorted(votes, key=votes.get, reverse=True):
                    if idx not in used_cols:
                        col = idx
                        break
                else:
                    fixed_col = consts.CLASSIFY_LIST[classify][1][consts.CLASSIFY_MAP[key]]
                    # fixed_col is 1-based while used_cols holds 0-based
                    # indexes, so compare the converted value.
                    if isinstance(fixed_col, int) and fixed_col - 1 not in used_cols:
                        col = fixed_col - 1
                if col is not None:
                    used_cols.add(col)
            statistics_header_info[key] = col
        statistics_header_info[consts.HEADER_KEY] = best_sheet_info.get(consts.HEADER_KEY)
        return statistics_header_info

    @staticmethod
    def get_data_col_min_row(sheet, sheet_header_info, header_info, classify):
        """Return the date column index and the first data row of one sheet.

        The date column is the index recorded for ``consts.DATE_KEY`` in the
        sheet itself. If the sheet has none, the most frequently seen date
        index in ``header_info`` that is not one of the sheet's recognised
        header columns is used; failing that, the fixed layout column for
        ``classify`` (1-based in ``consts.CLASSIFY_LIST``, converted to
        0-based) if it is an ``int`` and not already taken.

        Args:
            sheet: Sheet title used as key of ``sheet_header_info``.
            sheet_header_info: Per-sheet data as filled by ``header_collect``.
            header_info: ``column key -> {index: count}`` as filled by
                ``header_collect``.
            classify: Index into ``consts.CLASSIFY_LIST``.

        Returns:
            tuple: ``(date_col, min_row)`` where ``date_col`` is a 0-based
            index or ``None`` and ``min_row`` defaults to 2 when unknown.
        """
        sheet_info = sheet_header_info.get(sheet, {})
        date_col = sheet_info.get(consts.DATE_KEY)
        if date_col is None:
            votes = header_info.get(consts.DATE_KEY, {})
            taken_cols = sheet_info.get(consts.FIND_COL_KEY, set())
            for idx in sorted(votes, key=votes.get, reverse=True):
                if idx not in taken_cols:
                    date_col = idx
                    break
            else:
                fixed_col = consts.CLASSIFY_LIST[classify][1][consts.CLASSIFY_MAP[consts.DATE_KEY]]
                # fixed_col is 1-based, taken_cols holds 0-based indexes:
                # compare the converted value, not the raw one.
                if isinstance(fixed_col, int) and fixed_col - 1 not in taken_cols:
                    date_col = fixed_col - 1
        return date_col, sheet_info.get(consts.MIN_ROW_KEY, 2)

    @staticmethod
    def month_split(dti, date_list, date_statistics):
        month_list = []
        idx_list = []
        month_pre = None
        for idx, month_str in enumerate(dti.strftime('%Y-%m')):
            if isinstance(month_str, float):
                continue
            if month_str != month_pre:
                month_list.append(month_str)
                if month_pre is None:
                    if date_statistics:
                        date_list.append(dti[idx].date())
                    idx = 0
                idx_list.append(idx)
                month_pre = month_str
        if date_statistics:
            for idx in range(len(dti) - 1, -1, -1):
                if isinstance(dti[idx], NaTType):
                    continue
                date_list.append(dti[idx].date())
                break
        return month_list, idx_list

    @staticmethod
    def get_reverse_trend(day_idx, idx_list):
        reverse_trend = 0
        pre_day = None
        for idx, day in enumerate(day_idx):
            if np.isnan(day):
                continue
            if idx in idx_list or pre_day is None:
                pre_day = day
                continue
            if day < pre_day:
                reverse_trend += 1
                pre_day = day
            elif day > pre_day:
                reverse_trend -= 1
                pre_day = day
        if reverse_trend > 0:
            reverse_trend = 1
        elif reverse_trend < 0:
            reverse_trend = -1
        return reverse_trend

    def sheet_split(self, ws, date_col, min_row, month_mapping, reverse_trend_list, date_list, date_statistics):
        """Split the data rows of ``ws`` into per-month row ranges.

        The date column is parsed (strings are cut to their first 10
        characters, unparseable values become ``NaT``) and every run of equal
        months is registered in ``month_mapping`` as
        ``(sheet title, first row, last row, mean day of month)``. When there
        is no date column or no value can be parsed, the whole sheet is
        registered under the placeholder month ``'xxxx-xx'`` with a mean of 0.

        Args:
            ws: Source worksheet.
            date_col: 0-based index of the date column, or ``None``.
            min_row: First data row of ``ws`` (1-based).
            month_mapping: Dict ``month -> list of row ranges``, updated in place.
            reverse_trend_list: List that receives the ordering trend of this
                sheet (see ``get_reverse_trend``) when dates were found.
            date_list: Passed through to ``month_split``.
            date_statistics: Passed through to ``month_split``.
        """
        if date_col is None:
            month_mapping.setdefault('xxxx-xx', []).append((ws.title, min_row, ws.max_row, 0))
            return

        col_no = date_col + 1  # openpyxl columns are 1-based
        for column_values in ws.iter_cols(min_col=col_no, max_col=col_no, min_row=min_row, values_only=True):
            trimmed = [value[:10] if isinstance(value, str) else value for value in column_values]
            # np.asarray instead of np.array(..., copy=False): building an
            # array from a list always needs a copy, which NumPy 2 rejects
            # with copy=False. The resulting array is the same.
            dt_array, _ = tslib.array_to_datetime(
                np.asarray(trimmed, dtype=np.object_),
                errors="coerce",
                utc=False,
                dayfirst=False,
                yearfirst=False,
                require_iso8601=True,
            )
            dti = DatetimeIndex(dt_array, tz=None, name=None)

            month_list, idx_list = self.month_split(dti, date_list, date_statistics)
            if not month_list:
                month_mapping.setdefault('xxxx-xx', []).append((ws.title, min_row, ws.max_row, 0))
                continue

            day_idx = dti.day
            reverse_trend_list.append(self.get_reverse_trend(day_idx, idx_list))

            last_run = len(idx_list) - 1
            for i, month in enumerate(month_list):
                start = idx_list[i]
                if i == last_run:
                    # The last run extends to the end of the sheet.
                    day_mean = np.mean(day_idx[start:].dropna())
                    end_row = ws.max_row
                else:
                    stop = idx_list[i + 1]
                    day_mean = np.mean(day_idx[start:stop].dropna())
                    end_row = stop + min_row - 1
                month_mapping.setdefault(month, []).append(
                    (ws.title, start + min_row, end_row, day_mean))

    def build_metadata_rows(self, confidence, code, print_time, start_date, end_date):
        if start_date is None or end_date is None:
            timedelta = None
        else:
            timedelta = (end_date - start_date).days
        metadata_rows = [
            ('流水识别置信度', confidence),
            self.blank_row,
            self.code_header,
        ]
        metadata_rows.extend(code)
        metadata_rows.extend(
            [self.blank_row,
             self.date_header,
             (print_time, start_date, end_date, timedelta),
             self.blank_row,
             self.keyword_header]
        )
        return metadata_rows

    def create_meta_sheet(self, card):
        """Return the meta sheet for ``card``, reusing the default sheet if present.

        The sheet is titled ``'<meta_sheet_title>(<last six characters of
        card>)'``. If the first worksheet still has the title ``'Sheet'`` it
        is renamed and returned; otherwise a new sheet is created.

        Args:
            card: Card number; only its last six characters are used.

        Returns:
            The worksheet that will hold the meta information.
        """
        first_sheet = self.worksheets[0]
        reuse_default = first_sheet.title == 'Sheet'
        title = '{0}({1})'.format(self.meta_sheet_title, card[-6:])
        if not reuse_default:
            return self.create_sheet(title)
        first_sheet.title = title
        return first_sheet

    def build_meta_sheet(self, card, confidence, code, print_time, start_date, end_date):
        """Create the meta sheet for ``card`` and fill in its leading rows.

        The rows come from ``build_metadata_rows`` and the sheet from
        ``create_meta_sheet``.

        Returns:
            The filled meta worksheet.
        """
        rows = self.build_metadata_rows(confidence, code, print_time, start_date, end_date)
        meta_sheet = self.create_meta_sheet(card)
        for meta_row in rows:
            meta_sheet.append(meta_row)
        return meta_sheet

    @staticmethod
    def amount_format(amount_str):
        """Clean up an amount text so that it can be parsed as a number.

        Non-string values and the empty string are returned unchanged.
        Otherwise the text is processed in four steps:

        1. characters are replaced according to ``consts.TRANS``;
        2. a first character listed in ``consts.ERROR_CHARS`` becomes ``'-'``;
        3. every ``'-'`` after the first character is removed;
        4. the character three places from the end is treated as the decimal
           separator: a ``','`` directly in front of a ``'.'`` there is
           dropped, and a ``','`` in that position is turned into ``'.'``.

        Args:
            amount_str: Raw cell value.

        Returns:
            The cleaned string, or ``amount_str`` itself when it is not a
            non-empty string.
        """
        if not isinstance(amount_str, str) or amount_str == '':
            return amount_str
        # 1. character replacement
        res_str = amount_str.translate(consts.TRANS)
        if not res_str:
            # A translation table may delete characters; indexing an empty
            # result below would raise IndexError.
            return res_str
        # 2. first character handling
        first_char = '-' if res_str[0] in consts.ERROR_CHARS else res_str[0]
        # 3. drop superfluous minus signs
        res_str = first_char + res_str[1:].replace('-', '')
        # 4. comma / period handling around the decimal position
        if len(res_str) >= 4:
            period_idx = len(res_str) - 3
            if res_str[period_idx] == '.' and res_str[period_idx - 1] == ',':
                res_str = '{0}{1}'.format(res_str[:period_idx - 1], res_str[period_idx:])
            elif res_str[period_idx] == ',':
                res_str = '{0}.{1}'.format(res_str[:period_idx], res_str[period_idx + 1:])
        return res_str

    def build_month_sheet(self, ms, card, month_mapping, is_reverse, statistics_header_info, max_column):
        """Create one sheet per month from the collected row ranges and post-process it.

        For every month of ``month_mapping`` (in sorted order) a sheet titled
        ``'<month>(<last six characters of card>)'`` is created, the header is
        written and the non-empty rows of each listed source range are
        copied. Every copied row is then processed:

        * a summary found in ``self.interest_keyword`` is appended to ``ms``;
          one found in ``self.salary_keyword`` is collected and appended to
          ``ms`` below a second keyword header once all months are done; one
          found in ``self.loan_keyword`` gets highlighted;
        * the balance and amount cells are converted to numbers with
          ``locale.atof`` (best effort, conversion failures are ignored). If
          the amount itself cannot be parsed, a non-zero income is used as a
          positive amount, otherwise the outlay as a negative amount;
        * from the second data row on, a formula is written to the result
          column checking that balance equals the neighbouring balance plus
          the amount (the direction depends on ``is_reverse``);
        * amounts on the same date that cancel each other out are highlighted.

        Args:
            ms: Meta worksheet receiving the keyword rows.
            card: Card number; its last six characters go into sheet titles.
            month_mapping: ``month -> [(sheet title, first row, last row, ...)]``.
            is_reverse: ``True`` when the statement lists newest rows first.
            statistics_header_info: Column indexes and header row as returned
                by ``header_statistics``.
            max_column: Width of the widest source sheet.
        """
        summary_cell_idx = statistics_header_info.get(consts.SUMMARY_KEY)
        date_cell_idx = statistics_header_info.get(consts.DATE_KEY)
        amount_cell_idx = statistics_header_info.get(consts.AMOUNT_KEY)  # None, source column or appended column
        over_cell_idx = statistics_header_info.get(consts.OVER_KEY)
        income_cell_idx = statistics_header_info.get(consts.IMCOME_KEY)
        outlay_cell_idx = statistics_header_info.get(consts.OUTLAY_KEY)
        borrow_cell_idx = statistics_header_info.get(consts.BORROW_KEY)

        header = list(statistics_header_info.get(consts.HEADER_KEY))
        # Pad the header so appended columns land right of the widest sheet.
        header.extend([None] * (max_column - len(header)))

        add_col = ['核对结果']
        if amount_cell_idx is None and (income_cell_idx is not None or outlay_cell_idx is not None):
            # No amount column but income/outlay columns: derive one.
            add_col = ['金额', '核对结果']
            amount_cell_idx = len(header)
        header.extend(add_col)
        result_idx = len(header) - 1

        def cell_at(row_cells, idx):
            """Return the cell of ``row_cells`` at ``idx``, or None for an unknown column."""
            return None if idx is None else row_cells[idx]

        def value_of(cell):
            """Return ``cell.value``, or None when there is no cell."""
            return None if cell is None else cell.value

        # Salary keyword rows are kept in memory until all months are done.
        # A plain list replaces the former temporary worksheet, which stayed
        # in the workbook whenever an exception interrupted this method.
        salary_rows = []

        for month in sorted(month_mapping):
            # 3.1 copy the data
            new_ws = self.create_sheet('{0}({1})'.format(month, card[-6:]))
            new_ws.append(header)
            for part in month_mapping[month]:
                # self[title] is the non-deprecated form of get_sheet_by_name.
                src_ws = self[part[0]]
                for row_value in src_ws.iter_rows(min_row=part[1], max_row=part[2], values_only=True):
                    if any(row_value):
                        new_ws.append(row_value)

            # 3.2 extract keyword rows, highlight
            amount_mapping = {}  # date -> {amount: [row numbers]}
            amount_fill_row = set()

            for rows in new_ws.iter_rows(min_row=2):
                summary_cell = cell_at(rows, summary_cell_idx)
                date_cell = cell_at(rows, date_cell_idx)
                amount_cell = cell_at(rows, amount_cell_idx)
                over_cell = cell_at(rows, over_cell_idx)

                summary_cell_value = value_of(summary_cell)
                date_cell_value = value_of(date_cell)
                amount_cell_value = value_of(amount_cell)
                over_cell_value = value_of(over_cell)
                income_cell_value = value_of(cell_at(rows, income_cell_idx))
                outlay_cell_value = value_of(cell_at(rows, outlay_cell_idx))
                borrow_cell_value = value_of(cell_at(rows, borrow_cell_idx))

                if summary_cell is not None:
                    if summary_cell_value in self.interest_keyword:
                        # keyword group 1: straight to the meta sheet
                        ms.append((summary_cell_value, date_cell_value, amount_cell_value))
                    elif summary_cell_value in self.salary_keyword:
                        # keyword group 2: listed after all months
                        salary_rows.append((summary_cell_value, date_cell_value, amount_cell_value))
                    elif summary_cell_value in self.loan_keyword:
                        # NOTE(review): self.loan_fill exists but loan rows are
                        # filled with amount_fill, as before - confirm which
                        # colour is intended.
                        summary_cell.fill = self.amount_fill
                        if amount_cell is not None:
                            amount_cell.fill = self.amount_fill

                # 3.3 balance to number
                over_success = False
                if over_cell is not None:
                    try:
                        over_cell.value = locale.atof(self.amount_format(over_cell_value))
                    except Exception:
                        # Deliberate best effort: unparseable balances only
                        # disable the check for this row.
                        pass
                    else:
                        over_success = True
                        over_cell.number_format = numbers.FORMAT_NUMBER_00

                # 3.4 amount to number
                amount_success = False
                if amount_cell is not None:
                    try:
                        try:
                            amount_cell.value = locale.atof(self.amount_format(amount_cell_value))
                        except Exception:
                            try:
                                amount_cell.value = locale.atof(self.amount_format(income_cell_value))
                                if amount_cell.value == 0:
                                    # Zero income: fall through to the outlay.
                                    raise ValueError('income is zero')
                                elif amount_cell.value < 0:
                                    amount_cell.value = -amount_cell.value
                            except Exception:
                                amount_cell.value = locale.atof(self.amount_format(outlay_cell_value))
                                if amount_cell.value > 0:
                                    amount_cell.value = -amount_cell.value
                    except Exception:
                        # Deliberate best effort, see 3.3.
                        pass
                    else:
                        amount_success = True
                        if borrow_cell_value in consts.BORROW_OUTLAY_SET:
                            amount_cell.value = -amount_cell.value
                        amount_cell.number_format = numbers.FORMAT_NUMBER_00
                        if date_cell is not None:
                            same_day = amount_mapping.setdefault(date_cell.value, {})
                            offsetting_rows = same_day.get(-amount_cell.value)
                            if offsetting_rows:
                                amount_fill_row.add(amount_cell.row)
                                amount_fill_row.update(offsetting_rows)
                            same_day.setdefault(amount_cell.value, []).append(amount_cell.row)

                # 3.5 balance check formula (needs a preceding data row)
                if amount_success and over_success and amount_cell.row > 2:
                    amount_col_letter = get_column_letter(amount_cell_idx + 1)
                    over_col_letter = get_column_letter(over_cell_idx + 1)
                    row_no = amount_cell.row
                    if is_reverse:
                        checked_row, base_row = row_no - 1, row_no
                    else:
                        checked_row, base_row = row_no, row_no - 1
                    rows[result_idx].value = '=IF({2}{0}=ROUND(SUM({2}{1},{3}{0}),4), "{4}", "{5}")'.format(
                        checked_row, base_row, over_col_letter, amount_col_letter, *self.proof_res)

            # 3.6 highlight offsetting in/out amounts of the same day
            for row in amount_fill_row:
                new_ws[row][amount_cell_idx].fill = self.amount_fill
                if summary_cell_idx is not None:
                    new_ws[row][summary_cell_idx].fill = self.amount_fill

        # keyword group 2 section of the meta sheet
        ms.append(self.blank_row)
        ms.append(self.keyword_header)
        for row in salary_rows:
            ms.append(row)

    def bs_rebuild(self, bs_summary):
        """Rebuild the bank statement sheets of every card into month sheets.

        ``bs_summary`` maps a card number to a dict with the keys read here:
        ``'classify'`` (default 0), ``'confidence'`` (default 1), ``'code'``,
        ``'print_time'``, ``'start_date'``, ``'end_date'`` and ``'sheet'``
        (list of source sheet titles).

        Per card the source sheets are analysed, split by month, summarised
        in a meta sheet, merged into one sheet per month and finally removed
        from the workbook. Cards without source sheets are skipped.

        Args:
            bs_summary: Dict as described above.
        """
        for card, summary in bs_summary.items():
            sheets = summary.get('sheet') or []
            if not sheets:
                # Without source sheets there is nothing to analyse or merge;
                # the header analysis below cannot run on an empty list.
                continue

            # 1. collect headers of the source sheets and split them by month
            # 1.1 summarise the first rows
            classify = summary.get('classify', 0)
            sheet_header_info = {}
            header_info = {}
            max_column_list = []
            for sheet in sheets:
                self.header_collect(self[sheet], sheet_header_info, header_info, max_column_list, classify)
            statistics_header_info = self.header_statistics(sheet_header_info, header_info, classify)
            max_column = max(max_column_list)

            # 1.2 split by month (min_row: first data row, date_col: date column)
            start_date = summary.get('start_date')
            end_date = summary.get('end_date')
            # Dates are only collected from the sheets when the summary lacks one.
            date_statistics = start_date is None or end_date is None
            date_list = []
            month_mapping = {}  # month -> row ranges, used to create the month sheets
            reverse_trend_list = []  # per-sheet ordering trend
            for sheet in sheets:
                date_col, min_row = self.get_data_col_min_row(sheet, sheet_header_info, header_info, classify)
                self.sheet_split(self[sheet], date_col, min_row, month_mapping,
                                 reverse_trend_list, date_list, date_statistics)

            if date_statistics and len(date_list) > 1:
                if start_date is None:
                    start_date = min(date_list)
                if end_date is None:
                    end_date = max(date_list)

            # 2. meta information sheet
            ms = self.build_meta_sheet(card,
                                       summary.get('confidence', 1),
                                       summary.get('code'),
                                       summary.get('print_time'),
                                       start_date,
                                       end_date)

            # 3. month sheets; order the row ranges of each month by mean day,
            #    descending when the statement is in reverse order
            is_reverse = sum(reverse_trend_list) > 0
            for parts in month_mapping.values():
                parts.sort(key=lambda part: part[-1], reverse=is_reverse)

            self.build_month_sheet(ms, card, month_mapping, is_reverse, statistics_header_info, max_column)

            # 4. remove the source sheets
            for sheet in sheets:
                self.remove(self[sheet])

    def license_rebuild(self, license_summary, document_scheme):
        """Write one sheet per license type listing the recognised fields.

        ``consts.LICENSE_ORDER`` is walked in order; a type without entries in
        ``license_summary`` is skipped. Every license is written as
        ``(label, value)`` rows followed by one blank row.

        Special cases visible here:

        * when the type is flagged as scheme dependent and
          ``document_scheme`` equals ``consts.DOC_SCHEME_LIST[1]``, the type
          is handled as ``consts.MVC_CLASSIFY_SE``;
        * an ``consts.IC_CLASSIFY`` license whose ``'类别'`` is ``'1'`` is not
          written but moved to the ``consts.RP_CLASSIFY`` list of
          ``license_summary``;
        * when the type is flagged as side dependent, the field order is taken
          from ``consts.FIELD_ORDER_MAP`` depending on whether its marker key
          is present in the license.

        Args:
            license_summary: Dict ``type -> list of license dicts``; may be
                extended in place (see above).
            document_scheme: Scheme identifier compared against
                ``consts.DOC_SCHEME_LIST[1]``.
        """
        for classify, (_, sheet_name, field_order, side_diff, scheme_diff) in consts.LICENSE_ORDER:
            licenses = license_summary.get(classify)
            if not licenses:
                continue
            sheet = self.create_sheet(sheet_name)
            if scheme_diff and document_scheme == consts.DOC_SCHEME_LIST[1]:
                classify = consts.MVC_CLASSIFY_SE
            for license_dict in licenses:
                reroute = classify == consts.IC_CLASSIFY and license_dict.get('类别') == '1'
                if reroute:
                    license_summary.setdefault(consts.RP_CLASSIFY, []).append(license_dict)
                    continue
                if side_diff:
                    marker, order_with_marker, order_without_marker = consts.FIELD_ORDER_MAP.get(classify)
                    if marker in license_dict:
                        field_order = order_with_marker
                    else:
                        field_order = order_without_marker
                for search_field, write_field in field_order:
                    value = license_dict.get(search_field, '')
                    sheet.append((write_field, value))
                sheet.append((None, ))

    def skip_img_sheet(self, skip_img):
        """List the skipped images in their own sheet; do nothing when there are none.

        Args:
            skip_img: Iterable of row tuples, one per skipped image. The sheet
                is titled ``consts.SKIP_IMG_SHEET_NAME`` and starts with
                ``consts.SKIP_IMG_SHEET_HEADER``.
        """
        if not skip_img:
            return
        sheet = self.create_sheet(consts.SKIP_IMG_SHEET_NAME)
        sheet.append(consts.SKIP_IMG_SHEET_HEADER)
        for img_row in skip_img:
            sheet.append(img_row)

    def rebuild(self, bs_summary, license_summary, skip_img, document_scheme):
        """Run all rebuild steps on this workbook, in a fixed order.

        Bank statement sheets are rebuilt first, then the license sheets are
        written, then the sheet listing skipped images.

        Args:
            bs_summary: Passed to ``bs_rebuild``.
            license_summary: Passed to ``license_rebuild``.
            skip_img: Passed to ``skip_img_sheet``.
            document_scheme: Passed to ``license_rebuild``.
        """
        self.bs_rebuild(bs_summary)
        self.license_rebuild(license_summary, document_scheme)
        self.skip_img_sheet(skip_img)