# pred.py 1.72 KB
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import albumentations as A
import segmentation_models_pytorch as smp
import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

# Dataset locations. Input images are read from <root>/img, the list of images
# to process comes from <root>/test.csv (its `path` column is used by main),
# and predicted masks are written to <root>/visual.
root = '/home/mly/data/datasets/text_recognition/finetune/src/'
img_root, visual_root = (os.path.join(root, sub) for sub in ('img', 'visual'))
df = pd.read_csv(os.path.join(root, 'test.csv'))

def load_model(ckpt_path):
    """Build a DeepLabV3+ (ResNet-50 encoder) network and load saved weights.

    Args:
        ckpt_path: path to a file that ``torch.load`` turns into a state_dict
            whose keys match this architecture exactly (``load_state_dict`` is
            strict by default).

    Returns:
        The model in eval mode with its parameters on the CPU. It has one
        output class and a sigmoid activation, so it emits per-pixel
        probabilities.
    """
    model = smp.DeepLabV3Plus(
        encoder_name='resnet50',
        # The strict load below overwrites every weight, so fetching the
        # ImageNet encoder weights here would only cost time and network access.
        encoder_weights=None,
        in_channels=3,
        classes=1,
        activation='sigmoid',
    )
    # map_location lets a checkpoint saved from a GPU load on CPU-only machines.
    # The model itself lives on the CPU, so nothing is lost on GPU hosts either.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    return model


def infer(model, img, threshold=0.1, size=512):
    """Predict a binary segmentation mask for one RGB image.

    The image is resized to ``size`` x ``size`` and normalized with the
    ImageNet mean/std. It is then run through ``model``. The probability map is
    resized back to the input resolution and thresholded.

    Args:
        model: segmentation network in eval mode. It is expected to return a
            ``(1, C, size, size)`` map of per-pixel probabilities (C == 1 for
            the model built by ``load_model``).
        img: ``(H, W, 3)`` image in RGB channel order.
        threshold: probability at or above which a pixel counts as foreground.
        size: side length the image is resized to before inference.

    Returns:
        uint8 mask at the input resolution (``(H, W)`` when C == 1) containing
        only 0 (background) and 255 (foreground).
    """
    h, w = img.shape[:2]

    transform = A.Compose([
        A.Resize(size, size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    x = transform(image=img)['image']

    # HWC -> 1xCxHxW, placed on whichever device holds the model's weights.
    device = next(model.parameters()).device
    x = torch.from_numpy(np.transpose(x, (2, 0, 1))).unsqueeze(0)
    x = x.to(device=device, dtype=torch.float32)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        prob = model(x)

    prob = prob.squeeze(0).cpu().numpy().transpose((1, 2, 0))
    # Resize the probabilities rather than the thresholded mask. Bilinear
    # resizing of a 0/255 mask creates grey edge pixels, whereas thresholding
    # after the resize keeps the output strictly binary.
    # cv2.resize takes (width, height).
    prob = cv2.resize(prob, (w, h), interpolation=cv2.INTER_LINEAR)
    return np.where(prob >= threshold, 255, 0).astype(np.uint8)


def main(ckpt_path='./log/checkpoint/finetune_random_crop/ckpt_epoch_20.pt'):
    """Predict a mask for every image listed in ``df`` and save it under ``visual_root``.

    Each entry of ``df.path`` is resolved against ``img_root`` for reading and
    against ``visual_root`` for writing, so the output mirrors the input layout.

    Args:
        ckpt_path: checkpoint passed to ``load_model``.

    Raises:
        FileNotFoundError: if an input image cannot be read.
        OSError: if a mask cannot be written.
    """
    model = load_model(ckpt_path)
    for name in tqdm(df.path.to_list()):
        img_path = os.path.join(img_root, name)
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f'could not read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = infer(model, img)

        out_path = os.path.join(visual_root, name)
        # cv2.imwrite returns False instead of raising when the target
        # directory is missing, so create it first. This also covers any
        # sub-directories contained in `name`.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        # NOTE(review): the mask keeps the input file's extension. A lossy
        # format such as .jpg will smear the 0/255 values, so consider .png.
        if not cv2.imwrite(out_path, mask):
            raise OSError(f'could not write mask: {out_path}')


if __name__ == '__main__':
    main()