pred.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
from core.model.unet import Unet
from core.model.unet_skip import UnetSkip
# Earlier checkpoints kept for reference:
# def load_model(ckpt='/data1/lxl/code/ocr/removal/log/2022-11-10/baseline/ckpt_epoch_18.pt'):
# def load_model(ckpt='/data1/lxl/code/ocr/removal/log/2022-11-10/residual/ckpt_epoch_28.pt'):
def load_model(ckpt='./log/2022-11-11/skip/ckpt_epoch_30.pt'):
    """Load the skip-connection U-Net with its trained weights and return it in eval mode."""
    # Plain U-Net variant kept for reference:
    # model = Unet(
    #     encoder_name='resnet50',
    #     encoder_weights='imagenet',
    #     in_channels=3,
    #     classes=3,
    #     activation='tanh'
    # )
    model = UnetSkip(
        encoder_name='resnet50',
        encoder_weights='imagenet',
        in_channels=3,
        classes=3,
        activation='tanh'
    )
    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    model.eval()
    return model
def infer(model, img_path, gen_path):
    """Run text-removal inference on a single image and write the result to gen_path."""
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (512, 512))
    img = img / 255.
    img = img.transpose(2, 0, 1).astype(np.float32)  # HWC -> CHW
    img = torch.from_numpy(img).unsqueeze(0)         # add batch dimension
    with torch.no_grad():
        out = model(img)
    out = out['reconstruction']
    out = out.clamp(0, 1) * 255
    out = out.cpu().numpy().squeeze(0).transpose(1, 2, 0)  # CHW -> HWC
    out = cv2.cvtColor(out.astype(np.uint8), cv2.COLOR_RGB2BGR)
    cv2.imwrite(gen_path, out)
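# Note (not in the original script): although CUDA_VISIBLE_DEVICES is set above, the model
# is loaded and run on the CPU. A minimal sketch for GPU inference, assuming a CUDA device
# is available, would move the model and the input tensor before the forward pass:
#     model = load_model().cuda()
#     out = model(img.cuda())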
def infer_list():
    """Run inference over every image under img_root and save the results to gen_root."""
    img_root = '/data1/lxl/data/ocr/real/src/'
    gen_root = '/data1/lxl/data/ocr/real/removed/'
    os.makedirs(gen_root, exist_ok=True)  # make sure the output directory exists
    img_list = sorted(os.listdir(img_root))
    model = load_model()
    for img in tqdm(img_list):
        img_path = os.path.join(img_root, img)
        gen_path = os.path.join(gen_root, img)
        infer(model, img_path, gen_path)
def infer_img():
    """Run inference on a single hard-coded sample image."""
    model = load_model()
    img_path = '../DocEnTR/real/hetong_006_00.png'
    gen_path = './out.jpg'
    infer(model, img_path, gen_path)


if __name__ == '__main__':
    infer_img()
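# Usage note (assumption, not in the original script): run `python pred.py` to process the
# single sample image above; swap infer_img() for infer_list() under the __main__ guard to
# process the whole source directory instead.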