# pred.py 1.96 KB
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm


from core.model.unet import Unet
from core.model.unet_skip import UnetSkip

# def load_model(ckpt='/data1/lxl/code/ocr/removal/log/2022-11-10/baseline/ckpt_epoch_18.pt'):
# def load_model(ckpt='/data1/lxl/code/ocr/removal/log/2022-11-10/residual/ckpt_epoch_28.pt'):
def load_model(ckpt='./log/2022-11-11/skip/ckpt_epoch_30.pt'):
    """Build a ``UnetSkip`` network and restore its weights from ``ckpt``.

    The checkpoint is deserialized onto the CPU and fed to
    ``load_state_dict``; the network is switched to evaluation mode before
    being handed back.

    Args:
        ckpt: Path to a checkpoint file holding the model's state dict.

    Returns:
        The ``UnetSkip`` instance with the checkpoint weights loaded.
    """
    # NOTE: the baseline ``Unet`` was previously constructed here with these
    # same keyword arguments; swap the class below to evaluate it instead.
    net_kwargs = dict(
        encoder_name='resnet50',
        encoder_weights='imagenet',
        in_channels=3,
        classes=3,
        activation='tanh',
    )
    net = UnetSkip(**net_kwargs)

    state_dict = torch.load(ckpt, map_location='cpu')
    net.load_state_dict(state_dict)
    net.eval()

    return net


def infer(model, img_path, gen_path):
    """Run ``model`` on a single image and write the reconstruction to disk.

    The image is read with OpenCV, converted from BGR to RGB, resized to
    512x512, scaled to [0, 1] and passed to ``model`` as a float32 tensor of
    shape ``(1, 3, 512, 512)``. The ``'reconstruction'`` entry of the model
    output is clamped to [0, 1], rescaled to [0, 255], converted back to BGR
    and saved at ``gen_path``. The saved image keeps the 512x512 working
    resolution; it is not resized back to the input size.

    Args:
        model: Callable taking the input batch and returning a mapping with a
            ``'reconstruction'`` tensor (assumed ``(1, 3, H, W)``, as required
            by the squeeze/transpose/colour conversion below).
        img_path: Path of the image to process.
        gen_path: Destination path; the extension selects the output format.

    Raises:
        FileNotFoundError: If ``img_path`` cannot be read as an image.
        OSError: If OpenCV reports that the result could not be written.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'cannot read image: {img_path}')

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (512, 512))
    chw = (rgb / 255.).transpose(2, 0, 1).astype(np.float32)
    batch = torch.from_numpy(chw).unsqueeze(0)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(batch)['reconstruction']

    # Non-mutating clamp so the tensor held by the model's output is untouched.
    out = out.clamp(0, 1) * 255
    out = out.detach().cpu().numpy().squeeze(0).transpose(1, 2, 0)

    # Convert to 8-bit explicitly rather than relying on imwrite's implicit
    # float handling; values are already within [0, 255] after the clamp.
    out = np.ascontiguousarray(np.rint(out).astype(np.uint8))
    out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)

    if not cv2.imwrite(gen_path, out):
        # imwrite returns False (e.g. missing output directory) without raising.
        raise OSError(f'cannot write image: {gen_path}')

def infer_list(img_root='/data1/lxl/data/ocr/real/src/',
               gen_root='/data1/lxl/data/ocr/real/removed/'):
    """Run inference on every file in ``img_root`` and save results to ``gen_root``.

    Files are processed in sorted name order and each result is written under
    the same file name in ``gen_root``. The output directory is created if it
    does not exist. Entries of ``img_root`` that are not regular files are
    skipped, and the model is not loaded when there is nothing to process.

    Args:
        img_root: Directory containing the input images.
        gen_root: Directory receiving the generated images.
    """
    # cv2.imwrite does not create missing directories, so make sure it exists.
    os.makedirs(gen_root, exist_ok=True)

    names = sorted(
        name for name in os.listdir(img_root)
        if os.path.isfile(os.path.join(img_root, name))
    )
    if not names:
        return

    model = load_model()
    for name in tqdm(names):
        infer(model, os.path.join(img_root, name), os.path.join(gen_root, name))

def infer_img(img_path='../DocEnTR/real/hetong_006_00.png', gen_path='./out.jpg'):
    """Load the default model and run inference on a single image.

    Args:
        img_path: Path of the image to process; defaults to the sample image
            this script was originally hard-wired to.
        gen_path: Path the generated image is written to.
    """
    model = load_model()
    infer(model, img_path, gen_path)

# Run the single-image demo only when executed as a script, so that importing
# this module (e.g. to reuse load_model/infer) does not trigger inference.
if __name__ == '__main__':
    infer_img()