# main.py: command-line entry point that trains an F3Classification model.
import argparse
import os
from datetime import datetime

from model import F3Classification
import const


def parse_args(argv=None, base_dir=None):
    """Parse the training command line.

    Every option defaults to the value this script previously hard-coded,
    so running it with no arguments trains with the same settings as before.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
        base_dir: Directory for the default checkpoint file; ``None`` means
            the directory containing this script.

    Returns:
        argparse.Namespace with ``dataset_dir``, ``epochs``, ``batch_size``,
        ``ckpt_path``, ``train_dir_name`` and ``validate_dir_name``.

    Raises:
        SystemExit: on unparsable arguments, or when the epoch count or batch
            size is not a positive integer (argparse's standard error path).
    """
    if base_dir is None:
        base_dir = os.path.dirname(os.path.abspath(__file__))
    # Hyphens rather than colons in the time part: ':' is not allowed in
    # Windows file names and is misread as a host/drive separator by tools
    # such as scp.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    parser = argparse.ArgumentParser(description='Train an F3Classification model.')
    parser.add_argument('--dataset-dir', default='/home/zwq/data/data_224',
                        help='dataset root passed to F3Classification.train')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs (default: %(default)s)')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='training batch size (default: %(default)s)')
    parser.add_argument('--ckpt-path',
                        default=os.path.join(base_dir, 'ckpt_{0}.h5'.format(timestamp)),
                        help='checkpoint file path (default: timestamped file next to this script)')
    parser.add_argument('--train-dir-name', default='train',
                        help='name of the training split under the dataset root')
    parser.add_argument('--validate-dir-name', default='test',
                        help='name of the validation split under the dataset root')

    args = parser.parse_args(argv)
    # Fail fast, before the (potentially expensive) model construction.
    if args.epochs < 1:
        parser.error('--epochs must be a positive integer')
    if args.batch_size < 1:
        parser.error('--batch-size must be a positive integer')
    return args


if __name__ == '__main__':
    args = parse_args()

    m = F3Classification(
        class_name_list=const.CLASS_CN_LIST,
        class_other_first=const.CLASS_OTHER_FIRST
    )

    m.train(
        args.dataset_dir,
        args.epochs,
        args.batch_size,
        args.ckpt_path,
        train_dir_name=args.train_dir_name,
        validate_dir_name=args.validate_dir_name,
    )