# pred.py - run a DeepLabV3+ segmentation checkpoint over the test images and save the predicted masks.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import albumentations as A
import segmentation_models_pytorch as smp
import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
# Dataset root directory; the paths below are all derived from it.
root = '/home/mly/data/datasets/text_recognition/finetune/src/'
# Directory the input images are read from.
img_root = os.path.join(root, 'img')
# Table of test samples, loaded at import time; main() iterates its `path` column.
df = pd.read_csv(os.path.join(root, 'test.csv'))
# Directory the predicted masks are written to.
visual_root = os.path.join(root, 'visual')
def load_model(ckpt_path):
    """Build the DeepLabV3+ segmentation network and load fine-tuned weights.

    Args:
        ckpt_path: Path to a checkpoint file containing the model's
            ``state_dict``.

    Returns:
        The model, switched to evaluation mode.
    """
    model = smp.DeepLabV3Plus(
        encoder_name='resnet50',
        # The checkpoint loaded below replaces every weight (strict loading),
        # so fetching the ImageNet encoder weights first would be wasted work
        # and would require network access.
        encoder_weights=None,
        in_channels=3,
        classes=1,
        activation='sigmoid',
    )
    # map_location makes loading independent of the device the checkpoint
    # was saved from (e.g. a GPU index that does not exist on this machine).
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
def infer(model, img, threshold=0.1):
    """Predict a binary mask for one image.

    Args:
        model: Segmentation network returning per-pixel scores for a
            ``(1, C, 512, 512)`` float32 batch.
        img: Image array in height x width x channel layout (``main`` passes
            RGB images read with OpenCV).
        threshold: Scores greater than or equal to this value become
            foreground. Defaults to the previously hard-coded ``0.1``.

    Returns:
        A ``uint8`` mask with values 0 or 255, resized back to the height
        and width of ``img``.
    """
    h, w = img.shape[:2]
    transform = A.Compose([
        A.Resize(512, 512),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    arr = transform(image=img)['image']
    # HWC -> CHW, then add the batch dimension.
    arr = np.transpose(arr, (2, 0, 1))
    batch = torch.from_numpy(arr).unsqueeze(0).to(torch.float32)

    # Run on whatever device the model lives on, so a model moved to the GPU
    # by the caller does not fail with a device mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch = batch.to(first_param.device)

    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        scores = model(batch)
    scores = scores.squeeze(0).cpu().numpy().transpose((1, 2, 0))

    mask = np.where(scores >= threshold, 255, 0).astype(np.uint8)
    # Nearest-neighbour keeps the mask strictly binary; the default bilinear
    # interpolation would blend 0 and 255 into grey values along the edges.
    return cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
def main():
    """Run the fine-tuned model over every test image and save the masks.

    Image names come from the ``path`` column of ``df``. Each image is read
    from ``img_root`` and its predicted mask is written under ``visual_root``
    with the same relative name. Images that cannot be read and masks that
    cannot be written are reported and skipped instead of aborting the run.
    """
    model = load_model('./log/checkpoint/finetune_random_crop/ckpt_epoch_20.pt')
    name_list = df.path.to_list()
    for name in tqdm(name_list):
        src_path = os.path.join(img_root, name)
        img = cv2.imread(src_path)
        # cv2.imread signals a missing or undecodable file by returning None.
        if img is None:
            tqdm.write(f'skipping unreadable image: {src_path}')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = infer(model, img)

        dst_path = os.path.join(visual_root, name)
        # cv2.imwrite does not create directories; make sure the target
        # folder (including any sub-folder contained in `name`) exists.
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        # cv2.imwrite reports failure through its return value only.
        if not cv2.imwrite(dst_path, mask):
            tqdm.write(f'failed to write mask: {dst_path}')


if __name__ == '__main__':
    main()