# solver.py 3.9 KB
import torch
import torch.nn as nn
from segmentation_models_pytorch.losses import SoftCrossEntropyLoss
import segmentation_models_pytorch as smp
import os
import logging
from data.finetune_loader import *
from torch.optim import Adam, lr_scheduler
import torch.nn.functional as F
import cv2
import imutils

# Make only GPU 0 and GPU 1 visible to CUDA for this process.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0,1'})

# Append INFO-and-above log records to the fine-tuning log file.
logging.basicConfig(
    filename='/home/lxl/work/ocr/documentSeg/log/logfile/finetune_random_crop.log',
    filemode='a',
    level=logging.INFO,
)

def calc_iou(inputs, targets):
    """Compute the intersection-over-union of two tensors after flattening.

    Both tensors are detached from the autograd graph, moved to the CPU and
    flattened. Values are used as given (no thresholding), so probabilities
    in ``inputs`` yield a "soft" IoU.

    Args:
        inputs: Predicted mask or probabilities.
        targets: Ground-truth mask with the same number of elements.

    Returns:
        A 0-dim CPU tensor holding ``intersection / union``. When the union
        is zero (both tensors are all zero) the result is ``1.0`` rather
        than NaN, so one empty batch cannot poison an accumulated mean.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    inputs = inputs.detach().cpu().reshape(-1)
    targets = targets.detach().cpu().reshape(-1)

    # On bool masks "+" is a logical OR, which would give a wrong union;
    # do the arithmetic in floating point instead.
    if not inputs.is_floating_point():
        inputs = inputs.float()
    if not targets.is_floating_point():
        targets = targets.float()

    intersection = (inputs * targets).sum()
    union = (inputs + targets).sum() - intersection

    # Two empty masks agree perfectly; avoid 0 / 0 = NaN.
    if union == 0:
        return torch.ones_like(union)

    return intersection / union


class BCEDiceLoss(nn.Module):
    """Binary cross-entropy plus Dice loss, computed over all elements.

    ``input`` must already contain probabilities in ``[0, 1]`` (for example
    the output of a sigmoid), because plain binary cross-entropy is applied
    to it directly.
    """

    def __init__(self, weight=None, size_average=True):
        """Create the loss.

        Args:
            weight: Accepted for signature compatibility; has no effect.
            size_average: Accepted for signature compatibility; has no effect.
        """
        super().__init__()

    def forward(self, input, target):
        """Return ``BCE(input, target) + (1 - dice)`` as a float64 scalar.

        Args:
            input: Predicted probabilities, any shape.
            target: Ground-truth mask with the same number of elements;
                integer or bool masks are converted to the dtype of
                ``input``.

        Returns:
            A 0-dim double-precision tensor that is differentiable with
            respect to ``input``.
        """
        # reshape (unlike view) also accepts non-contiguous tensors.
        pred = input.reshape(-1)
        # BCE requires a floating-point target of the same dtype as the input.
        truth = target.reshape(-1).to(pred.dtype)

        # BCE loss (mean over all elements)
        bce_loss = F.binary_cross_entropy(pred, truth).double()

        # Dice coefficient with +1 smoothing so empty masks do not divide by 0
        dice_coef = (2.0 * (pred * truth).double().sum() + 1) / (
            pred.double().sum() + truth.double().sum() + 1
        )

        return bce_loss + (1 - dice_coef)


class Solver:
    """Fine-tunes a DeepLabV3+ (ResNet-50 encoder) single-class segmenter.

    The model is wrapped in ``nn.DataParallel`` and moved to CUDA, trained
    with Adam and ``BCEDiceLoss``, checkpointed after every epoch and then
    evaluated on the test loader.
    """

    # Checkpoint directory used when ``save_dir`` is not given.
    DEFAULT_SAVE_DIR = '/home/lxl/work/ocr/documentSeg/log/checkpoint/finetune_random_crop/'

    def __init__(self, epoch=20, lr=0.0008, weight_decay=0.00001, save_dir=None):
        """Build data loaders, model, optimizer and loss.

        Args:
            epoch: Number of training epochs.
            lr: Initial learning rate for Adam.
            weight_decay: Weight decay passed to Adam.
            save_dir: Directory for checkpoints; ``DEFAULT_SAVE_DIR`` when
                ``None``. It is created if missing.
        """
        self.train_loader, self.test_loader = get_loader()
        self.len_train_loader = len(self.train_loader)
        self.len_test_loader = len(self.test_loader)

        self.model = smp.DeepLabV3Plus(
            encoder_name='resnet50',
            encoder_weights='imagenet',
            in_channels=3,
            classes=1,
            activation='sigmoid'
        )

        self.save_dir = self.DEFAULT_SAVE_DIR if save_dir is None else save_dir
        self.model = nn.DataParallel(self.model)
        self.model = self.model.cuda()

        self.epoch = epoch
        self.lr = lr
        self.weight_decay = weight_decay
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.save_dir, exist_ok=True)

        self.optimizer = Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        self.criterion = BCEDiceLoss().cuda()

    def train(self):
        """Run the full training loop.

        After each epoch the learning rate schedule is stepped (halved every
        2 epochs), the mean training IoU is logged, the unwrapped model's
        weights are saved as ``ckpt_epoch_<n>.pt`` and ``val()`` is run.
        """
        lr_decay = lr_scheduler.StepLR(self.optimizer, step_size=2, gamma=0.5)

        for t in range(self.epoch):
            logging.info("==============EPOCH {} START================".format(t + 1))
            self.model.train()

            all_iou = 0.0

            for i, (x, y) in enumerate(self.train_loader):
                x = x.cuda()
                y = y.cuda()

                self.optimizer.zero_grad()

                pred = self.model(x)

                loss = self.criterion(pred, y)
                # calc_iou detaches its inputs, so this does not affect gradients.
                all_iou += calc_iou(pred, y)

                iteration = i + 1
                if iteration % 50 == 0 or iteration == 1:
                    logging.info(
                        'epoch: {}/{}, iteration: {}/{}, loss: {:.4f}'.format(
                            t + 1, self.epoch, iteration, self.len_train_loader, loss.item()))

                loss.backward()
                self.optimizer.step()

            lr_decay.step()

            # Guard the divisor so an empty loader logs 0 instead of raising.
            miou = all_iou / max(self.len_train_loader, 1)

            logging.info('EPOCH: {}/{}, MIOU: {:.4f}'.format(t + 1, self.epoch, miou))

            # Save the wrapped module's weights so the checkpoint loads
            # without the DataParallel "module." key prefix.
            ckpt_path = os.path.join(self.save_dir, 'ckpt_epoch_{}.pt'.format(t + 1))
            torch.save(self.model.module.state_dict(), ckpt_path)

            self.val()

    @torch.no_grad()
    def val(self):
        """Evaluate on the test loader and log the mean per-batch IoU.

        Returns:
            The mean IoU over all test batches (``0.0`` for an empty loader).
        """
        self.model.eval()

        all_iou = 0.0

        for x, y in self.test_loader:
            x = x.cuda()
            y = y.cuda()

            pred = self.model(x)

            all_iou += calc_iou(pred, y)

        # Guard the divisor so an empty loader logs 0 instead of raising.
        miou = all_iou / max(self.len_test_loader, 1)

        logging.info('VAL MIOU: {:.4f}'.format(miou))

        return miou