# hed_edge_detection_solver.py
import torch
import torch.nn as nn
import os
import logging
from model.hed import *
from data.edge_loader import get_loader
from torch.optim import Adam, lr_scheduler
import torch.nn.functional as F
import cv2

# Restrict PyTorch to GPUs 0 and 1 (nn.DataParallel below spreads across both).
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# Append INFO-level training logs to a fixed logfile path.
logging.basicConfig(level=logging.INFO,
                    filename='/home/lxl/work/ocr/documentSeg/log/logfile/hed_edge_detection.log',
                    filemode='a')


class Solver:
    """Training/validation driver for the HED edge detector (ResNet-34 backbone).

    Builds the data loaders, wraps the model in ``nn.DataParallel`` on the GPUs
    selected via ``CUDA_VISIBLE_DEVICES``, and trains with Adam + a per-epoch
    halving LR schedule, checkpointing and validating after every epoch.
    """

    def __init__(self, epoch=10, lr=0.00005, weight_decay=0.00001,
                 save_dir='/home/lxl/work/ocr/documentSeg/log/checkpoint/hed_edge_detection'):
        """Set up loaders, model, and optimizer.

        Args:
            epoch: number of training epochs (default 10, as before).
            lr: initial Adam learning rate.
            weight_decay: L2 penalty passed to Adam.
            save_dir: checkpoint directory; created if missing.
        """
        self.train_loader, self.test_loader = get_loader()
        self.len_train_loader = len(self.train_loader)
        self.len_test_loader = len(self.test_loader)

        self.model = HED_res34()
        self.model = nn.DataParallel(self.model)
        self.model = self.model.cuda()

        self.epoch = epoch
        self.lr = lr
        self.weight_decay = weight_decay
        self.save_dir = save_dir
        # exist_ok avoids the exists()/makedirs() race of the original check.
        os.makedirs(self.save_dir, exist_ok=True)

        self.optimizer = Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def train(self):
        """Run the full training loop; checkpoint and validate after each epoch."""
        # Halve the learning rate after every epoch (step_size=1, gamma=0.5).
        lr_decay = lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)

        for t in range(self.epoch):
            logging.info("==============EPOCH {} START================".format(t + 1))
            self.model.train()

            for i, (x, y) in enumerate(self.train_loader):
                x = x.cuda()
                y = y.cuda()

                self.optimizer.zero_grad()

                pred = self.model(x)

                # RCF-style balanced cross entropy; targets must be int64.
                loss = cross_entropy_loss_RCF(pred, y.type(torch.int64))

                iteration = i + 1
                if iteration % 50 == 0 or iteration == 1:
                    # .item() extracts the Python scalar so the log line does
                    # not hold a reference to the autograd graph.
                    logging.info(
                        'epoch: {}/{}, iteration: {}/{}, loss: {:.4f}'.format(
                            t + 1, self.epoch, iteration, self.len_train_loader, loss.item()))

                loss.backward()
                self.optimizer.step()

            lr_decay.step()

            logging.info('EPOCH: {}/{} finished'.format(t + 1, self.epoch))

            self.model.eval()
            # Save the wrapped module's weights so checkpoints load without
            # requiring DataParallel (no 'module.' key prefix).
            torch.save(self.model.module.state_dict(),
                       '%s/ckpt_epoch_%s.pt' % (self.save_dir, str(t + 1)))

            self.test()

    @torch.no_grad()
    def test(self):
        """Evaluate on the validation loader and log the mean validation loss.

        The original loop performed forward passes but reported nothing; it now
        accumulates the same RCF loss used for training and logs its average.
        """
        self.model.eval()

        total_loss = 0.0
        for i, (x, y) in enumerate(self.test_loader):
            x = x.cuda()
            y = y.cuda()

            pred = self.model(x)
            total_loss += cross_entropy_loss_RCF(pred, y.type(torch.int64)).item()

        # Guard against an empty validation loader.
        avg_loss = total_loss / max(self.len_test_loader, 1)
        logging.info('VAL loss: {:.4f}'.format(avg_loss))