# sl_solver.py 14.3 KB
import copy
import os
import cv2
import json

import torch
from PIL import Image, ImageDraw, ImageFont

from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
from utils import sequence_mask
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report


@SOLVER_REGISTRY.register()
class SLSolver(object):

    def __init__(self, cfg):
        """Build every component of the solver from ``cfg``.

        ``cfg`` is handed to the project ``build_*`` factories and is indexed
        here at ``cfg['solver']['optimizer']['args']``,
        ``cfg['solver']['args']`` and ``cfg['solver']['logger']``.

        Raises:
            KeyError: if ``epoch`` is missing from ``cfg['solver']['args']``
                (other missing keys surface from the plain lookups below).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Private snapshot of the config; `run` later reads the lr-scheduler
        # settings from this copy.
        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        # build_optimizer returns a callable that is then applied to the
        # model parameters plus the configured keyword arguments.
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        self.base_on = self.hyper_params['base_on']
        self.model_path = self.hyper_params['model_path']
        self.val_image_path = self.hyper_params['val_image_path']
        self.val_label_path = self.hyper_params['val_label_path']
        self.val_go_path = self.hyper_params['val_go_path']
        self.val_map_path = self.hyper_params['val_map_path']
        self.draw_font_path = self.hyper_params['draw_font_path']
        self.thresholds = self.hyper_params['thresholds']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError as err:
            # Raising a bare string is itself a TypeError in Python 3, which
            # hid this message; raise a real exception and keep the cause.
            raise KeyError('should contain epoch in {solver.args}') from err

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    def accuracy(self, y_pred, y_true, valid_lens, eval=False):
        # [batch_size, seq_len, num_classes]
        y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
        # [batch_size, seq_len]
        y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
        # [batch_size, seq_len]
        y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > self.thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=-1) + 1
        y_true_is_other = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        if eval:
            return y_pred_rebuild, y_true_rebuild

        masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)

        return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()

    def train_loop(self):
        """Run one training epoch over ``self.train_loader`` and log metrics.

        Every batch is a ``(X, y, valid_lens)`` triple. The summed loss of the
        batch is back-propagated and the optimizer stepped once per batch.
        At the end it logs accuracy (matching positions divided by the sum of
        ``valid_lens``) and the mean of the per-batch summed losses.
        """
        self.model.train()

        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            pred = self.model(X, valid_lens)
            # [batch_size, seq_len, num_classes]

            # Reduce once and reuse for bookkeeping, logging and backward.
            batch_loss = self.loss_fn(pred, y, valid_lens).sum()
            # Detach before accumulating: adding the grad-tracking loss would
            # chain every batch's autograd history onto the running total.
            train_loss += batch_loss.detach()

            if batch % 100 == 0:
                loss_value, current = batch_loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            # Metric only; no gradient is needed for the label comparison.
            with torch.no_grad():
                correct += self.accuracy(pred, y, valid_lens)

        # correct /= self.train_dataset_size
        correct /= seq_lens_sum
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        """Score the model on ``self.val_loader`` and log accuracy and loss.

        Runs with gradients disabled and the model in eval mode. Accuracy is
        the number of matching positions divided by the sum of ``valid_lens``;
        the loss is the mean of the per-batch summed losses.

        Args:
            t: passed by ``run`` as the 1-based epoch number; not used here.
        """
        self.model.eval()

        dev = self.device
        token_total = torch.zeros(1).to(dev)
        loss_total = torch.zeros(1).to(dev)
        hit_total = torch.zeros(1).to(dev)

        for X, y, valid_lens in self.val_loader:
            X = X.to(dev)
            y = y.to(dev)
            valid_lens = valid_lens.to(dev)

            # [batch_size, seq_len, num_classes]
            scores = self.model(X, valid_lens)

            loss_total += self.loss_fn(scores, y, valid_lens).sum()
            token_total += valid_lens.sum()
            hit_total += self.accuracy(scores, y, valid_lens)

        val_accuracy = hit_total / token_total
        val_mean_loss = loss_total / self.val_loader_size

        self.logger.info(f"val accuracy: {val_accuracy.item() :.4f}, val mean loss: {val_mean_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        """Train for ``self.epoch`` epochs, validating and checkpointing each one.

        If ``self.base_on`` is a path to an existing file, its weights are
        loaded first. The lr scheduler is built from
        ``self.cfg['solver']['lr_scheduler']['args']`` and stepped once per
        epoch, after the checkpoint is written.
        """
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            # map_location keeps a checkpoint saved on GPU loadable when
            # self.device has fallen back to CPU.
            self.model.load_state_dict(torch.load(self.base_on, map_location=self.device))
            self.logger.info(f'==> Load Model from {self.base_on}')
        elif self.base_on:
            # Previously silent: make it visible that no weights were loaded.
            self.logger.info(f'==> base_on not found, training from scratch: {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            lr_scheduler.step()

        self.logger.info('==> End Training')

    # def run(self):
    #     from torch.nn import functional

    #     y = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)
    #     valid_lens = torch.randint(50, 100, (8, ))
    #     print(valid_lens)

    #     pred = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)

    #     print(self.accuracy(pred, y, valid_lens))

    def evaluate(self):
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        label_true_list = []
        label_pred_list = []
        for X, y, valid_lens in self.val_loader:
            X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # pred = torch.nn.Sigmoid()(self.model(X))
            y_pred = self.model(X, valid_lens)
            
            y_pred_rebuild, y_true_rebuild = self.accuracy(y_pred, y_true, valid_lens, eval=True)

            for idx, seq_result in enumerate(y_true_rebuild.cpu().numpy().tolist()):
                label_true_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]])
            for idx, seq_result in enumerate(y_pred_rebuild.cpu().numpy().tolist()):
                label_pred_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]])
        
        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
    
    def draw_val(self):
        if not os.path.isdir(self.val_image_path):
            print('Warn: val_image_path not exists: {0}'.format(self.val_image_path))    
            return

        if not os.path.isdir(self.val_label_path):
            print('Warn: val_label_path not exists: {0}'.format(self.val_label_path))    
            return

        if not os.path.isdir(self.val_go_path):
            print('Warn: val_go_path not exists: {0}'.format(self.val_go_path))    
            return

        if not os.path.isfile(self.val_map_path):
            print('Warn: val_map_path not exists: {0}'.format(self.val_map_path))    
            return

        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        map_key_input = 'x_y_valid_lens'
        map_key_text = 'find_top_text'
        map_key_value = 'find_value'
        test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]
        group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
        skip_list_valid = [
            # 'CH-B102897920-2.jpg',
            # 'CH-B102551284-0.jpg',
            # 'CH-B102879376-2.jpg',
            # 'CH-B101509488-page-16.jpg',
            # 'CH-B102708352-2.jpg',
        ]

        dataset_base_dir = os.path.dirname(self.val_map_path)
        val_dataset_dir = os.path.join(dataset_base_dir, 'valid')
        save_dir = os.path.join(dataset_base_dir, 'draw_val')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir, exist_ok=True)

        with open(self.val_map_path, 'r') as fp:
            val_map = json.load(fp) 
        
        data_dict = {key_cn: [0, 0] for key_cn in group_cn_list[1:]}
        failed_dict = dict()
        for img_name in sorted(os.listdir(self.val_image_path)):
            if img_name in skip_list_valid:
                continue

            print('Info: start {0}'.format(img_name))
            image_path = os.path.join(self.val_image_path, img_name)

            img = cv2.imread(image_path)
            im_h, im_w, _ = img.shape
            img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(img_pil)

            if im_h < im_w:
                size = int(im_h * 0.010)
            else:
                size = int(im_w * 0.010)
            if size < 10:
                size = 10
            font = ImageFont.truetype(self.draw_font_path, size, encoding='utf-8')

            green_color = (0, 255, 0)
            red_color = (255, 0, 0) 
            blue_color = (0, 0, 255) 

            base_image_name, _ = os.path.splitext(img_name)
            go_res_json_path = os.path.join(self.val_go_path, '{0}.json'.format(base_image_name))
            with open(go_res_json_path, 'r') as fp:
                go_res_list = json.load(fp)

            with open(os.path.join(val_dataset_dir, val_map[img_name][map_key_input]), 'r') as fp:
                input_list, label_list, valid_lens_scalar = json.load(fp)

            X = torch.tensor(input_list).unsqueeze(0).to(self.device)
            y_true = torch.tensor(label_list).unsqueeze(0).float().to(self.device)
            valid_lens = torch.tensor([valid_lens_scalar, ]).to(self.device)
            del input_list
            del label_list

            y_pred = self.model(X, valid_lens)

            y_pred_rebuild, y_true_rebuild = self.accuracy(y_pred, y_true, valid_lens, eval=True)
            pred = y_pred_rebuild.cpu().numpy().tolist()[0]
            label = y_true_rebuild.cpu().numpy().tolist()[0]

            correct = 0
            bbox_draw_dict = dict()
            bbox_text_dict = dict()
            for i in range(valid_lens_scalar):
                if pred[i] != 0:
                    bbox_text_dict.setdefault(pred[i]-1, list()).append(i) 

            #     if pred[i] == label[i]:
            #         correct += 1
            #         if pred[i] != 0:
            #             # 绿色
            #             bbox_draw_dict[i] = (group_cn_list[pred[i]], )
            #     else:
            #         # 红色:左上角label,右上角pred
            #         bbox_draw_dict[i] = (group_cn_list[label[i]], group_cn_list[pred[i]])

            # correct_rate = correct / valid_lens_scalar
            
            # 画图
            # for idx, text_tuple in bbox_draw_dict.items():
            #     (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[idx]
            #     line_color = green_color if len(text_tuple) == 1 else red_color
            #     draw.polygon([(x0, y0), (x1, y1), (x2, y2), (x3, y3)], outline=line_color)
            #     draw.text((int(x0), int(y0)), text_tuple[0], green_color, font=font)
            #     if len(text_tuple) == 2:
            #         draw.text((int(x1), int(y1)), text_tuple[1], red_color, font=font)

            # draw.text((0, 0), str(correct_rate), blue_color, font=font)

            # last_y = size 
            # for k, v in val_map[img_name][map_key_value].items():
            #     draw.text((0, last_y), '{0}: {1}'.format(k, v), blue_color, font=font)
            #     last_y += size

            # img_pil.save(os.path.join(save_dir, img_name))

            # 统计准确率
            label_json_path = os.path.join(self.val_label_path, '{0}.json'.format(base_image_name))
            with open(label_json_path, 'r') as fp:
                label_res = json.load(fp)

            group_text_list = []
            for group_id in test_group_id:
                for item in label_res.get("shapes", []):
                    if item.get("group_id") == group_id:
                        group_text_list.append(item['label'])
                        break
                else:
                    group_text_list.append(None)

            for idx, text in enumerate(group_text_list):
                key_cn = group_cn_list[idx+1]

                pred_idx_list = bbox_text_dict.get(idx)
                if isinstance(pred_idx_list, list):
                    pred_text_list = [go_res_list[idx][-1] for idx in pred_idx_list]
                    pred_text = ' '.join(pred_text_list)
                else:
                    pred_text = None

                data_dict[key_cn][-1] += 1
                if pred_text == text:
                    data_dict[key_cn][0] += 1 
                else:
                    failed_dict.setdefault(key_cn, list()).append((text, pred_text))

            # break
        
        for key_cn, (correct_count, all_count) in data_dict.items():
            print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2)))
    
        print('===========================')

        for key_cn, failed_list in failed_dict.items():
            print(key_cn)
            for text, pred_text in failed_list:
                print('label: {0} pred: {1}'.format(text, pred_text))
            print('----------------------------------')