# comparison.py (13.5 KB)
import re
import time
import numpy as np
from datetime import datetime
from dateutil.relativedelta import relativedelta
from .rmb_lower import rmb_handler
# from .rmb_upper import to_rmb_upper
from pandas._libs import tslib
from pandas._libs.tslibs.nattype import NaTType
from pandas.core.indexes.datetimes import DatetimeIndex


class Comparison:
    """Compare user-entered field values against OCR-extracted text.

    Two families of comparators are provided:

    * ``<field>_compare(input_str, ocr_str, idx, **kwargs)`` return a
      ``(result, value)`` tuple. ``result`` is ``RESULT_Y`` / ``RESULT_N`` /
      ``RESULT_NA``, and ``value`` is the OCR value (possibly normalized) to
      report back.
    * ``se_*`` / ``ca_*`` comparators return only the result string.
    """

    def __init__(self):
        # Business-type codes produced by the type comparators.
        self.CSIBM = 'CSIBM'
        self.CSSME = 'CSSME'
        self.CSOTH = 'CSOTH'

        # (regex, code) pairs tried in order against the OCR text. The first
        # match wins; CSOTH is the fallback when nothing matches.
        self.TYPE_MAPPING = (
            (r'个体工商户', self.CSIBM),
            (r'有限责任公司', self.CSSME),
            (r'个人独资企业', self.CSSME),
            (r'有限合伙企业', self.CSSME),
            (r'股份合作制', self.CSSME),
        )

        self.RESULT_Y = 'Y'
        self.RESULT_N = 'N'
        self.RESULT_NA = 'NA'

        # Characters deleted from names before comparison: space and '·'.
        self.TRANS_MAP = {
            ' ': '',
            '·': '',
        }
        self.TRANS = str.maketrans(self.TRANS_MAP)
        # Parenthesized segments removed from company names before comparison.
        # Matches both full-width （…） and ASCII (…) brackets. The earlier class
        # ``[(\(]`` listed the ASCII '(' twice and never matched '（'.
        self.re_obj = r'[（(].*?[）)]'

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def build_res(self, result):
        """Map a truthy value to RESULT_Y and a falsy one to RESULT_N."""
        return self.RESULT_Y if result else self.RESULT_N

    def _na_result(self, input_str, ocr_str):
        """Return the NA tuple when the pair cannot be compared, else None.

        A non-string on either side yields ``(RESULT_NA, ocr_str)``. A blank
        (empty or whitespace-only) OCR string yields ``(RESULT_NA, None)``.
        """
        if not isinstance(ocr_str, str) or not isinstance(input_str, str):
            return self.RESULT_NA, ocr_str
        if not ocr_str.strip():
            return self.RESULT_NA, None
        return None

    def _strip_parens(self, text):
        """Remove parenthesized segments (see ``re_obj``) and trim whitespace."""
        return re.sub(self.re_obj, '', text).strip()

    @staticmethod
    def _passport_contains(input_str, ocr_str):
        """Check if ``ocr_str`` occurs in ``input_str``, ignoring case and spaces."""
        return ocr_str.upper().replace(' ', '') in input_str.upper().replace(' ', '')

    def _names_equal(self, input_str, ocr_str):
        """Compare two names after deleting the characters in ``TRANS_MAP``."""
        return input_str.translate(self.TRANS) == ocr_str.translate(self.TRANS)

    def _classify_type(self, ocr_str):
        """Return the code of the first ``TYPE_MAPPING`` pattern found, else CSOTH."""
        for pattern, code in self.TYPE_MAPPING:
            if re.search(pattern, ocr_str) is not None:
                return code
        return self.CSOTH

    @staticmethod
    def _normalize_dates(input_str, ocr_str, kwargs):
        """Apply the optional date-normalization steps selected by ``kwargs``.

        - ``ocr_split``: keep the OCR text after the last '至', or, failing
          that, after the last '-'.
        - ``ocr_replace``: turn '年'/'月' into '-' and drop '日'.
        - ``input_replace``: replace '-' in the input with the given separator.

        Returns the ``(input_str, ocr_str)`` pair after these steps.
        """
        if kwargs.get('ocr_split', False):
            if '至' in ocr_str:
                ocr_str = ocr_str.split('至')[-1]
            elif '-' in ocr_str:
                ocr_str = ocr_str.split('-')[-1]
        if kwargs.get('ocr_replace', False):
            ocr_str = ocr_str.replace('年', '-').replace('月', '-').replace('日', '')
        if kwargs.get('input_replace') is not None:
            input_str = input_str.replace('-', kwargs.get('input_replace'))
        return input_str, ocr_str

    # ------------------------------------------------------------------
    # (result, value) comparators
    # ------------------------------------------------------------------

    def common_compare(self, input_str, ocr_str, idx, **kwargs):
        """Check for exact string equality. ``idx`` is accepted but unused."""
        na = self._na_result(input_str, ocr_str)
        if na is not None:
            return na
        return self.build_res(input_str == ocr_str), ocr_str

    def company_compare(self, input_str, ocr_str, idx, **kwargs):
        """Check company-name equality after removing parenthesized segments."""
        na = self._na_result(input_str, ocr_str)
        if na is not None:
            return na
        same = self._strip_parens(input_str) == self._strip_parens(ocr_str)
        return self.build_res(same), ocr_str

    def name_compare(self, input_str, ocr_str, idx, **kwargs):
        """Compare personal names.

        With ``is_passport``, Y means the OCR name occurs in the input name,
        ignoring case and spaces. Otherwise Y means the names are equal after
        deleting the ``TRANS_MAP`` characters.
        """
        na = self._na_result(input_str, ocr_str)
        if na is not None:
            return na
        if kwargs.get('is_passport'):
            matched = self._passport_contains(input_str, ocr_str)
        else:
            matched = self._names_equal(input_str, ocr_str)
        return self.build_res(matched), ocr_str

    def date_compare(self, input_str, ocr_str, idx, **kwargs):
        """Compare dates, returning ``(result, ocr date as 'YYYY-MM-DD' or None)``.

        With ``long``, OCR text containing '长期'/'永久' matches only the input
        dates '2099-12-31' or '2099-01-01', and the reported value is
        '2099-12-31'. Otherwise the strings are normalized (see
        ``_normalize_dates``) and compared exactly. The reported value is the
        OCR date re-formatted to ISO, or None if it cannot be parsed.
        """
        na = self._na_result(input_str, ocr_str)
        if na is not None:
            return na
        if kwargs.get('long', False) and ('长期' in ocr_str or '永久' in ocr_str):
            matched = input_str in ('2099-12-31', '2099-01-01')
            return self.build_res(matched), '2099-12-31'
        input_str, ocr_str = self._normalize_dates(input_str, ocr_str, kwargs)
        sep = kwargs.get('input_replace')
        fmt = '%Y-%m-%d' if sep is None else '%Y{0}%m{0}%d'.format(sep)
        try:
            ocr_output = datetime.strptime(ocr_str, fmt).strftime('%Y-%m-%d')
        except (TypeError, ValueError):
            ocr_output = None
        return self.build_res(input_str == ocr_str), ocr_output

    def rmb_compare(self, input_str, ocr_str, idx, **kwargs):
        """Compare a numeric amount with an OCR'd RMB amount.

        The OCR text is converted with ``rmb_handler.to_rmb_lower`` and compared
        against ``float(input_str)``. On a match the input string is reported;
        otherwise the converted OCR value is reported. Any conversion failure
        gives ``(RESULT_N, None)``.
        """
        if self._na_result(input_str, ocr_str) is not None:
            # Unlike the other comparators, never echo the raw OCR value here.
            return self.RESULT_NA, None
        try:
            ocr_lower = rmb_handler.to_rmb_lower(ocr_str)
            res = self.build_res(float(input_str) == ocr_lower)
        except Exception:
            # The exceptions rmb_handler can raise are not visible here, so any
            # failure is treated as a mismatch.
            return self.RESULT_N, None
        if res == self.RESULT_Y:
            return res, input_str
        return res, ocr_lower

    def type_compare(self, input_str, ocr_str, idx, **kwargs):
        """Classify the OCR text via ``TYPE_MAPPING`` and compare the code to the input."""
        na = self._na_result(input_str, ocr_str)
        if na is not None:
            return na
        compare_str = self._classify_type(ocr_str)
        return self.build_res(input_str == compare_str), compare_str

    # ------------------------------------------------------------------
    # Result-only comparators (se_* / ca_*)
    # ------------------------------------------------------------------

    def se_name_compare(self, input_str, ocr_str, **kwargs):
        """Result-only name comparison. A blank OCR value never matches in passport mode."""
        if kwargs.get('is_passport'):
            return self.build_res(
                self._passport_contains(input_str, ocr_str) and bool(ocr_str.strip())
            )
        return self.build_res(self._names_equal(input_str, ocr_str))

    def ca_name_compare(self, input_str, ocr_str, **kwargs):
        """Same rules as ``se_name_compare``."""
        return self.se_name_compare(input_str, ocr_str, **kwargs)

    def se_common_compare(self, input_str, ocr_str, **kwargs):
        """Check for exact equality."""
        return self.build_res(input_str == ocr_str)

    def ca_common_compare(self, input_str, ocr_str, **kwargs):
        """Check for exact equality."""
        return self.build_res(input_str == ocr_str)

    @staticmethod
    def is_after_today(ocr_str):
        """Return True if ``ocr_str`` parses (ISO-8601) to a date on or after today.

        Unparseable input (NaT) returns False.
        """
        # NOTE(review): relies on private pandas internals
        # (``pandas._libs.tslib.array_to_datetime``), whose keyword arguments
        # (e.g. ``require_iso8601``) differ across pandas releases -- confirm
        # against the pinned pandas version.
        # ``copy=False`` is omitted: a list always has to be copied, and NumPy 2
        # raises ValueError when that flag cannot be honoured.
        dt_array, _ = tslib.array_to_datetime(
            np.array([ocr_str], dtype=np.object_),
            errors="coerce",
            utc=False,
            dayfirst=False,
            yearfirst=False,
            require_iso8601=True,
        )
        ts = DatetimeIndex(dt_array, tz=None, name=None)[0]
        if isinstance(ts, NaTType):
            return False
        return not ts.date() < datetime.today().date()

    def se_date_compare(self, input_str, ocr_str, **kwargs):
        """Result-only date comparison.

        With ``long``, OCR text containing a long-term marker matches when
        ``today`` is set or the input is one of the 2099/2999 sentinel dates.
        Otherwise the strings are normalized (see ``_normalize_dates``). With
        ``today``, the check is whether the OCR date is not in the past;
        without it, the normalized strings are compared exactly.
        """
        today = kwargs.get('today', False)
        if kwargs.get('long', False):
            markers = ('长期', '永久', '***', '至今', '年—月—日', '年 月 日')
            if any(marker in ocr_str for marker in markers):
                sentinels = ('2099-12-31', '2099-01-01', '2999-12-31', '2999-01-01')
                return self.build_res(today or input_str in sentinels)
        input_str, ocr_str = self._normalize_dates(input_str, ocr_str, kwargs)
        if today:
            return self.build_res(self.is_after_today(ocr_str))
        return self.build_res(input_str == ocr_str)

    def ca_date_compare(self, input_str, ocr_str, **kwargs):
        """Result-only date comparison with only the '长期'/'永久' long-term markers."""
        if kwargs.get('long', False) and ('长期' in ocr_str or '永久' in ocr_str):
            return self.build_res(input_str in ('2099-12-31', '2099-01-01'))
        input_str, ocr_str = self._normalize_dates(input_str, ocr_str, kwargs)
        return self.build_res(input_str == ocr_str)

    def se_contain_compare(self, input_str, ocr_str, **kwargs):
        """Return Y if ``input_str`` occurs in a non-blank ``ocr_str``."""
        return self.build_res(input_str in ocr_str and bool(ocr_str.strip()))

    def se_contain_compare_2(self, input_str, ocr_str, **kwargs):
        """Return Y if a non-blank ``ocr_str`` occurs in ``input_str``."""
        return self.build_res(ocr_str in input_str and bool(ocr_str.strip()))

    def se_both_contain_compare(self, input_str, ocr_str, **kwargs):
        """Return Y if either string contains the other and ``ocr_str`` is non-blank."""
        contained = input_str in ocr_str or ocr_str in input_str
        return self.build_res(contained and bool(ocr_str.strip()))

    def se_amount_compare(self, input_str, ocr_str, **kwargs):
        """Match on identical values, else on equal float conversions. Unparseable gives N."""
        if input_str == ocr_str:
            return self.RESULT_Y
        try:
            float_input = float(input_str)
            float_ocr = float(ocr_str)
        except (TypeError, ValueError):
            return self.RESULT_N
        return self.build_res(float_ocr == float_input)

    def se_company_compare(self, input_str, ocr_str, **kwargs):
        """Check company-name equality after removing parenthesized segments."""
        return self.build_res(self._strip_parens(input_str) == self._strip_parens(ocr_str))

    def ca_company_compare(self, input_str, ocr_str, **kwargs):
        """Same rules as ``se_company_compare``."""
        return self.se_company_compare(input_str, ocr_str, **kwargs)

    def se_rmb_compare(self, input_str, ocr_str, **kwargs):
        """Compare ``float(input_str)`` with ``rmb_handler.to_rmb_lower(ocr_str)``."""
        try:
            ocr_lower = rmb_handler.to_rmb_lower(ocr_str)
            matched = float(input_str) == ocr_lower
        except Exception:
            # The exceptions rmb_handler can raise are not visible here, so any
            # failure is treated as a mismatch.
            return self.RESULT_N
        return self.build_res(matched)

    def ca_rmb_compare(self, input_str, ocr_str, **kwargs):
        """Same rules as ``se_rmb_compare``."""
        return self.se_rmb_compare(input_str, ocr_str, **kwargs)

    def se_type_compare(self, input_str, ocr_str, **kwargs):
        """Return Y if the ``TYPE_MAPPING`` code for ``ocr_str`` equals ``input_str``."""
        return self.build_res(input_str == self._classify_type(ocr_str))

    def ca_type_compare(self, input_str, ocr_str, **kwargs):
        """Same rules as ``se_type_compare``."""
        return self.se_type_compare(input_str, ocr_str, **kwargs)

    def se_date_compare_2(self, input_str, ocr_str, **kwargs):
        """Return Y if the OCR date is on or after the input date.

        Both values must be 'YYYY-MM-DD'; anything unparseable gives N. With
        ``three_month``, the threshold is raised to at least three months
        before today.
        """
        try:
            input_date = datetime.strptime(input_str, "%Y-%m-%d").date()
            ocr_date = datetime.strptime(ocr_str, "%Y-%m-%d").date()
        except (TypeError, ValueError):
            return self.RESULT_N
        if kwargs.get('three_month', False):
            three_months_ago = (datetime.today() - relativedelta(months=3)).date()
            input_date = max(input_date, three_months_ago)
        return self.build_res(input_date <= ocr_date)


# Shared module-level instance. Comparison only assigns constant configuration
# in __init__ and none of its methods reassign attributes, so callers can reuse
# this one object.
cp: Comparison = Comparison()