# base_class.py
# (601 Bytes)
class BaseModel:
    """
    All Model classes should extend BaseModel.

    Defines the interface every model must provide: ``load_model`` and
    ``train``. In this base class both methods only raise
    ``NotImplementedError``; subclasses are expected to override them.
    """

    def load_model(self, for_training=False, load_weights_path=None):
        """
        Define the network structure and return it.

        Must be overridden by subclasses.

        Args:
            for_training: flag passed through to the subclass; presumably
                selects a build suitable for training -- confirm against
                subclasses.
            load_weights_path: optional value, ``None`` by default;
                presumably a path to weights to load -- confirm against
                subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # Name the real method so the error tells the implementer exactly
        # what to override.
        raise NotImplementedError(".load_model() must be overridden.")

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
              train_dir_name='train', validate_dir_name='test', thresholds=0.5,
              metrics_name='accuracy'):
        """
        Run the model training process.

        Must be overridden by subclasses.

        Args:
            dataset_dir: dataset location; presumably the parent directory
                of ``train_dir_name`` / ``validate_dir_name`` -- confirm.
            epoch: number of training epochs (per the name -- confirm).
            batch_size: batch size used during training.
            ckpt_path: where checkpoints are saved (per the name -- confirm).
            history_save_path: where training history is saved (per the
                name -- confirm).
            train_dir_name: training subset name, ``'train'`` by default.
            validate_dir_name: validation subset name, ``'test'`` by default.
            thresholds: threshold value, ``0.5`` by default; its exact use
                is up to the subclass.
            metrics_name: metric name, ``'accuracy'`` by default.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(".train() must be overridden.")