# train.py
# -*- coding: utf-8 -*-
# @Author        : lk
# @Email         : 9428.al@gmail.com
# @Create Date   : 2022-04-19 21:15:37
# @Last Modified : 2022-04-23 15:06:00
# @Description   : Train a MobileUNet text detector with DBNet-style targets and loss.

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import cv2
import time
import json
import math
import pyclipper
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from shapely.geometry import Polygon
from shapely.geometry.polygon import LinearRing

from utils import MobileUNet, resize_with_padding, DBNetLoss


def gen_dbmap(image_path, label_path, size=1024, shrink_ratio=0.4):
    """Load one image and its polygon annotations and build DBNet target maps.

    Intended to be called through ``tf.numpy_function``, which passes string
    tensors as UTF-8 ``bytes``; both paths are decoded here.

    Args:
        image_path: UTF-8 encoded path of the image file (bytes).
        label_path: UTF-8 encoded path of the JSON annotation file (bytes).
            The JSON must contain a ``shapes`` list whose items carry a
            ``points`` list of (x, y) coordinates in original-image pixels.
        size: side length of the square output canvas. It is passed to
            ``resize_with_padding`` and used for every target map. The default
            of 1024 preserves the original behaviour.
        shrink_ratio: ratio ``r`` in the border width
            ``D = A * (1 - r**2) / L`` (A = polygon area, L = perimeter).

    Returns:
        ``(image, label)``. ``image`` is the output of ``resize_with_padding``
        cast to float32 (presumably ``size x size x 3`` -- confirm in utils).
        ``label`` is a float32 array of shape ``(size, size, 3)`` stacking:
          * prob_map   -- filled polygons minus their border band,
          * thre_map   -- 1 on a band of width D drawn along each outline,
          * binary_map -- the filled polygons before the band is removed.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    prob_map = np.zeros((size, size), dtype=np.float32)
    thre_map = np.zeros((size, size), dtype=np.float32)

    image_path = image_path.decode('utf-8')
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('cannot read image: {}'.format(image_path))
    # Random channel-order reversal (BGR <-> RGB) as a colour augmentation.
    if np.random.uniform(0, 1) > 0.5:
        image = image[..., ::-1]
    image, ratio = resize_with_padding(image, size)
    # numpy_function is declared with float32 outputs; enforce that here.
    image = np.asarray(image, dtype=np.float32)

    label_path = label_path.decode('utf-8')
    with open(label_path, 'r', encoding='utf-8') as f:
        label_dict = json.load(f)

    for shape in label_dict['shapes']:
        # Scale annotation coordinates into the resized/padded canvas.
        points = np.array(shape['points']) * ratio
        if len(points) < 3:
            # LinearRing needs at least 3 vertices; skip degenerate shapes
            # just as invalid polygons are skipped below.
            continue
        # Reverse counter-clockwise rings so every polygon has the same
        # orientation.
        if LinearRing(points).is_ccw:
            points = points[::-1, :]
        points = np.array(points, dtype=np.int32)

        polygon = Polygon(points)
        if not polygon.is_valid:
            continue
        distance = int(polygon.area * (1 - np.power(shrink_ratio, 2)) / polygon.length)

        cv2.fillPoly(prob_map, [points], 1)
        cv2.polylines(thre_map, [points], isClosed=True, color=1, thickness=distance)

    # np.clip returns a new array, so binary_map keeps the unshrunk mask.
    binary_map = prob_map
    prob_map = np.clip(prob_map - thre_map, 0, 1)

    label = np.stack([prob_map, thre_map, binary_map], axis=-1)
    return image, label

def data_augment(image, label):
    """Randomly mirror an (image, label) pair, flipping both identically.

    A horizontal flip and then a vertical flip are each applied
    independently with probability 0.25, so the label maps stay aligned
    with the image pixels.
    """
    for flip_op in (tf.image.flip_left_right, tf.image.flip_up_down):
        if tf.random.uniform(()) < 0.25:
            image, label = flip_op(image), flip_op(label)
    return image, label


def _pair_paths(image_dir, label_dir):
    """Return parallel lists of image paths and matching JSON label paths.

    Every file in ``image_dir`` is treated as an image. Its label is the file
    with the same stem and a ``.json`` extension in ``label_dir``, whatever
    the image extension is.
    """
    # List once and sort, so both lists come from a single deterministic scan.
    names = sorted(os.listdir(image_dir))
    images = [os.path.join(image_dir, fn) for fn in names]
    labels = [os.path.join(label_dir, os.path.splitext(fn)[0] + '.json') for fn in names]
    return images, labels


def _build_dataset(images, labels, augment=False, batch_size=2):
    """Build a shuffled, batched tf.data pipeline of (image, label) pairs.

    Targets are generated on the fly by ``gen_dbmap`` via ``tf.numpy_function``.
    When ``augment`` is true, ``data_augment`` is applied as well.
    """
    ds = tf.data.Dataset.from_tensor_slices((images, labels)).shuffle(10086)
    ds = ds.map(
        lambda img, lbl: tf.numpy_function(gen_dbmap, [img, lbl], [tf.float32, tf.float32]),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    if augment:
        ds = ds.map(data_augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds.batch(batch_size).prefetch(buffer_size=32)


if __name__ == '__main__':

    data_dir = './dataset/wild_200/'
    # Additional training data at an absolute, machine-specific location.
    extra_dir = '/home/lk/MyProject/文字检测相关/DBNet2022/dataset/inter_ocr'
    ckpt_path = './model/ckpt.h5'

    train_images, train_labels = _pair_paths(
        os.path.join(data_dir, 'train/image'), os.path.join(data_dir, 'train/json'))
    extra_images, extra_labels = _pair_paths(
        os.path.join(extra_dir, 'image'), os.path.join(extra_dir, 'json'))
    train_images += extra_images
    train_labels += extra_labels
    valid_images, valid_labels = _pair_paths(
        os.path.join(data_dir, 'test/image'), os.path.join(data_dir, 'test/json'))

    train_ds = _build_dataset(train_images, train_labels, augment=True)
    valid_ds = _build_dataset(valid_images, valid_labels)

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = MobileUNet()
        model.summary()
        # Warm-start from a previous run when a checkpoint exists; on the very
        # first run there is nothing to load.
        if os.path.exists(ckpt_path):
            model.load_weights(ckpt_path, by_name=True, skip_mismatch=True)

        loss_fn = DBNetLoss()
        optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
        model.compile(loss=loss_fn, optimizer=optimizer)

    cp_callback = tf.keras.callbacks.ModelCheckpoint(ckpt_path, monitor='val_loss', save_best_only=True)

    history = model.fit(train_ds,
                        validation_data=valid_ds,
                        epochs=200,
                        callbacks=[cp_callback],
                        )

    # Restore the best checkpoint into the already-compiled model. Unlike
    # load_model, this keeps the model inside the strategy scope. It also does
    # not need custom_objects for project layers (NOTE(review): confirm whether
    # MobileUNet/DBNetLoss would have required them).
    if os.path.exists(ckpt_path):
        model.load_weights(ckpt_path)
    model.evaluate(valid_ds)

    metrics = ['loss']
    fig, axes = plt.subplots(1, len(metrics), figsize=(10 * len(metrics), 3), squeeze=False)
    for ax, metric in zip(axes.ravel(), metrics):
        ax.plot(history.history[metric])
        ax.plot(history.history['val_' + metric])
        ax.set_title('Model {}'.format(metric))
        ax.set_xlabel('epochs')
        ax.set_ylabel(metric)
        ax.legend(['train', 'val'])

    plt.savefig('./loss.png')