# sl_solver.py
import copy
import os
import cv2
import json

import torch
from PIL import Image, ImageDraw, ImageFont

from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
from utils import sequence_mask
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report


@SOLVER_REGISTRY.register()
class SLSolver(object):
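    """Sequence-labeling solver: builds the data loaders, model, loss and optimizer
    from the config, then handles training, validation, offline evaluation and
    visualization of per-box predictions on the validation images."""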

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        self.base_on = self.hyper_params['base_on']
        self.model_path = self.hyper_params['model_path']
        self.val_image_path = self.hyper_params['val_image_path']
        self.val_go_path = self.hyper_params['val_go_path']
        self.val_map_path = self.hyper_params['val_map_path']
        self.draw_font_path = self.hyper_params['draw_font_path']
        self.thresholds = self.hyper_params['thresholds']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError("solver.args must contain 'epoch'")

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    def accuracy(self, y_pred, y_true, valid_lens, eval=False):
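        """Rebuild per-token class indices from logits and one-hot labels, then compare them.

        Class 0 is reserved for "other": a prediction becomes argmax + 1 only when its best
        sigmoid score exceeds ``self.thresholds``, and a target becomes argmax + 1 only when
        its one-hot row is non-zero; otherwise the index collapses to 0. With ``eval=True``
        the rebuilt prediction and target sequences are returned; otherwise target positions
        beyond ``valid_lens`` are masked to -1 and the count of matching tokens is returned.
        """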
        # [batch_size, seq_len, num_classes]
        y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
        # [batch_size, seq_len]
        y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
        # [batch_size, seq_len]
        y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > self.thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=-1) + 1
        y_true_is_other = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        if eval:
            return y_pred_rebuild, y_true_rebuild

        masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)

        return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()

    def train_loop(self):
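        """Run one training epoch and log per-token accuracy and mean batch loss."""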
        self.model.train()

        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)
            
            pred = self.model(X, valid_lens)
            # [batch_size, seq_len, num_classes]

            loss_sum = self.loss_fn(pred, y, valid_lens).sum()
            # detach the running total so it does not keep every batch's autograd graph alive
            train_loss += loss_sum.detach()

            if batch % 100 == 0:
                self.logger.info(f'train iteration: {batch}/{self.train_loader_size}, train loss: {loss_sum.item():.4f}')

            self.optimizer.zero_grad()
            loss_sum.backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # correct /= self.train_dataset_size
        correct /= seq_lens_sum 
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
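        """Run one validation pass and log per-token accuracy and mean batch loss.

        ``t`` is the 1-based epoch index passed by ``run`` (currently unused in the body).
        """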
        self.model.eval()

        seq_lens_sum = torch.zeros(1).to(self.device)
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y, valid_lens in self.val_loader:
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # pred = torch.nn.Sigmoid()(self.model(X))
            pred = self.model(X, valid_lens)
            # [batch_size, seq_len, num_classes]

            loss = self.loss_fn(pred, y, valid_lens)
            val_loss += loss.sum()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # correct /= self.val_dataset_size
        correct /= seq_lens_sum
        val_loss /= self.val_loader_size
            
        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
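        """Save the model state dict as ``ckpt_epoch_{epoch_id}.pt`` under the log dir."""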
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
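        """Training entry point: optionally resume from ``base_on``, then train, validate
        and checkpoint once per epoch, stepping the LR scheduler after each epoch."""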
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            self.model.load_state_dict(torch.load(self.base_on))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            lr_scheduler.step()

        self.logger.info('==> End Training')

    # def run(self):
    #     from torch.nn import functional

    #     y = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)
    #     valid_lens = torch.randint(50, 100, (8, ))
    #     print(valid_lens)

    #     pred = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)

    #     print(self.accuracy(pred, y, valid_lens))

    @torch.no_grad()
    def evaluate(self):
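        """Load the checkpoint at ``model_path`` (return silently if it is missing) and
        report token-level accuracy, confusion matrix and classification report over
        the validation set."""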
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        label_true_list = []
        label_pred_list = []
        for X, y, valid_lens in self.val_loader:
            X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # pred = torch.nn.Sigmoid()(self.model(X))
            y_pred = self.model(X, valid_lens)
            
            y_pred_rebuild, y_true_rebuild = self.accuracy(y_pred, y_true, valid_lens, eval=True)

            for idx, seq_result in enumerate(y_true_rebuild.cpu().numpy().tolist()):
                label_true_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]])
            for idx, seq_result in enumerate(y_pred_rebuild.cpu().numpy().tolist()):
                label_pred_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]])
        
        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
    
    @torch.no_grad()
    def draw_val(self):
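        """Render predictions onto the validation images.

        Correctly predicted non-"other" boxes are outlined in green with the field name;
        mismatches are outlined in red with the true label at the top-left corner and the
        prediction at the top-right. The per-image accuracy and the expected field values
        from the map file are drawn in blue, and results are saved to a ``draw_val``
        directory next to ``val_map_path``.
        """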
        if not os.path.isdir(self.val_image_path):
            print('Warn: val_image_path does not exist: {0}'.format(self.val_image_path))
            return

        if not os.path.isdir(self.val_go_path):
            print('Warn: val_go_path does not exist: {0}'.format(self.val_go_path))
            return

        if not os.path.isfile(self.val_map_path):
            print('Warn: val_map_path does not exist: {0}'.format(self.val_map_path))
            return

        map_key_input = 'x_y_valid_lens'
        map_key_text = 'find_top_text'
        map_key_value = 'find_value'
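        # field names kept in Chinese because they are drawn onto the images; rough glosses:
        # other, invoice date, invoice code, machine-printed number, vehicle type, phone,
        # engine number, VIN, account number, bank of deposit, amount in figures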
        group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']

        dataset_base_dir = os.path.dirname(self.val_map_path)
        val_dataset_dir = os.path.join(dataset_base_dir, 'valid')
        save_dir = os.path.join(dataset_base_dir, 'draw_val')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir, exist_ok=True)

        self.model.eval()

        with open(self.val_map_path, 'r') as fp:
            val_map = json.load(fp) 
        
        for img_name in sorted(os.listdir(self.val_image_path)):
            print('Info: start {0}'.format(img_name))
            image_path = os.path.join(self.val_image_path, img_name)

            img = cv2.imread(image_path)
            im_h, im_w, _ = img.shape
            img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(img_pil)

            if im_h < im_w:
                size = int(im_h * 0.015)
            else:
                size = int(im_w * 0.015)
            if size < 14:
                size = 14
            font = ImageFont.truetype(self.draw_font_path, size, encoding='utf-8')

            green_color = (0, 255, 0)
            red_color = (255, 0, 0) 
            blue_color = (0, 0, 255) 

            base_image_name, _ = os.path.splitext(img_name)
            go_res_json_path = os.path.join(self.val_go_path, '{0}.json'.format(base_image_name))
            with open(go_res_json_path, 'r') as fp:
                go_res_list = json.load(fp)

            with open(os.path.join(val_dataset_dir, val_map[img_name][map_key_input]), 'r') as fp:
                input_list, label_list, valid_lens_scalar = json.load(fp)

            X = torch.tensor(input_list).unsqueeze(0).to(self.device)
            y_true = torch.tensor(label_list).unsqueeze(0).float().to(self.device)
            valid_lens = torch.tensor([valid_lens_scalar, ]).to(self.device)
            del input_list
            del label_list

            y_pred = self.model(X, valid_lens)

            y_pred_rebuild, y_true_rebuild = self.accuracy(y_pred, y_true, valid_lens, eval=True)
            pred = y_pred_rebuild.cpu().numpy().tolist()[0]
            label = y_true_rebuild.cpu().numpy().tolist()[0]

            correct = 0
            bbox_draw_dict = dict()
            for i in range(valid_lens_scalar):
                if pred[i] == label[i]:
                    correct += 1
                    if pred[i] != 0:
                        # green: correct, non-"other" prediction
                        bbox_draw_dict[i] = (group_cn_list[pred[i]], )
                else:
                    # red: true label at the top-left corner, prediction at the top-right
                    bbox_draw_dict[i] = (group_cn_list[label[i]], group_cn_list[pred[i]])

            correct_rate = correct / valid_lens_scalar
            
            # draw the annotated boxes
            for idx, text_tuple in bbox_draw_dict.items():
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[idx]
                line_color = green_color if len(text_tuple) == 1 else red_color
                draw.polygon([(x0, y0), (x1, y1), (x2, y2), (x3, y3)], outline=line_color)
                draw.text((int(x0), int(y0)), text_tuple[0], green_color, font=font)
                if len(text_tuple) == 2:
                    draw.text((int(x1), int(y1)), text_tuple[1], red_color, font=font)

            draw.text((0, 0), str(correct_rate), blue_color, font=font)

            last_y = size 
            for k, v in val_map[img_name][map_key_value].items():
                draw.text((0, last_y), '{0}: {1}'.format(k, v), blue_color, font=font)
                last_y += size

            img_pil.save(os.path.join(save_dir, img_name))

            # break
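

# Minimal usage sketch (not part of the original pipeline): it assumes the config is a
# dict with the same layout consumed above ('solver', 'data', 'model', ... sections)
# stored as JSON; the path below is hypothetical.
if __name__ == '__main__':
    with open('config/sl_solver.json', 'r') as fp:
        cfg = json.load(fp)

    solver = SLSolver(cfg)
    solver.run()         # train + validate + checkpoint every epoch
    # solver.evaluate()  # token-level metrics on the validation set
    # solver.draw_val()  # draw predictions onto the validation images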