# main.py
# 803 Bytes
import os
from datetime import datetime
from model import F3Classification
import const
# Default location of the training data. Set the F3_DATASET_DIR environment
# variable to use another location without editing this file.
DEFAULT_DATASET_DIR = '/home/zwq/data/data_224_f3'

# Timestamp format embedded in every artifact name of a run.
# NOTE(review): ':' is not a legal filename character on Windows; the format is
# kept unchanged for compatibility with existing artifact names.
TIMESTAMP_FORMAT = '%Y-%m-%d_%H:%M:%S'


def build_artifact_paths(base_dir, now=None):
    """Return ``(ckpt_path, history_save_path)`` for one training run.

    Both file names embed the same timestamp, so the checkpoint and the
    history image belonging to one run can be matched afterwards.

    Args:
        base_dir: Directory in which both paths are placed.
        now: ``datetime`` used for the timestamp; defaults to
            ``datetime.now()`` (naive local time).

    Returns:
        A tuple ``(<base_dir>/ckpt_<stamp>.h5, <base_dir>/history_<stamp>.jpg)``.
    """
    if now is None:
        now = datetime.now()
    # Format a single timestamp so both names are guaranteed to agree, even if
    # the clock ticks over a second between building the two paths.
    stamp = now.strftime(TIMESTAMP_FORMAT)
    ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(stamp))
    history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(stamp))
    return ckpt_path, history_save_path


def main():
    """Build an F3Classification model and call its ``train`` method.

    The checkpoint and history-image paths handed to ``m.train`` are located
    in the directory containing this script.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    m = F3Classification(
        class_name_list=const.CLASS_CN_LIST,
        class_other_first=const.CLASS_OTHER_FIRST,
    )
    # m.test()
    dataset_dir = os.environ.get('F3_DATASET_DIR', DEFAULT_DATASET_DIR)
    ckpt_path, history_save_path = build_artifact_paths(base_dir)
    epoch = 100
    batch_size = 128
    # The 'test' subdirectory is passed as the validation directory.
    m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
            train_dir_name='train', validate_dir_name='test',
            thresholds=const.OTHER_THRESHOLDS)


if __name__ == '__main__':
    main()