37a9d47e by 周伟奇

classification train

1 parent cbeebc6d
......@@ -10,6 +10,7 @@ class Retriever:
self.key_text_set = self.get_key_text_set(target_fields)
def get_key_text_set(self, target_fields):
# 关键词集合
key_text_set = set()
for key_text_list in target_fields[self.keys_str].values():
for key_text, _, _ in key_text_list:
......@@ -18,11 +19,13 @@ class Retriever:
@staticmethod
def key_top1(coordinates_list, key_coordinates):
# 关键词查找方向:最上面
coordinates_list.sort(key=lambda x: x[1])
return coordinates_list[0]
@staticmethod
def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
# 关键词查找方向:右侧
if len(coordinates_list) == 1:
return coordinates_list[0]
height = key_coordinates[-1] - key_coordinates[1]
......@@ -41,6 +44,7 @@ class Retriever:
@staticmethod
def value_right(go_res, key_coordinates, top_padding, bottom_padding):
# 字段值查找方向:右侧
height = key_coordinates[-1] - key_coordinates[1]
y_min = key_coordinates[1] - (top_padding * height)
y_max = key_coordinates[-1] + (bottom_padding * height)
......@@ -57,6 +61,7 @@ class Retriever:
@staticmethod
def value_under(go_res, key_coordinates, left_padding, right_padding):
# 字段值查找方向:下方
width = key_coordinates[2] - key_coordinates[0]
x_min = key_coordinates[0] - (width * left_padding)
x_max = key_coordinates[2] + (width * right_padding)
......
......@@ -9,7 +9,8 @@ class BaseModel:
"""
raise NotImplementedError(".load() must be overridden.")
def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
train_dir_name='train', validate_dir_name='test'):
"""
Model training process
"""
......
......@@ -14,10 +14,12 @@ if __name__ == '__main__':
# m.test()
dataset_dir = '/home/zwq/data/data_224'
dataset_dir = '/home/zwq/data/data_224_f3'
ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
epoch = 100
batch_size = 128
m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test')
m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
train_dir_name='train', validate_dir_name='test')
......
......@@ -9,37 +9,6 @@ import matplotlib.pyplot as plt
from base_class import BaseModel
@tf.function
def random_rgb_2_bgr(image, label):
if random.random() > 0.5:
return image, label
image = image[:, :, ::-1]
return image, label
@tf.function
def random_grayscale_expand(image, label):
if random.random() > 0.1:
return image, label
image = tf.image.rgb_to_grayscale(image)
image = tf.image.grayscale_to_rgb(image)
return image, label
@tf.function
def load_image(image_path, label):
image = tf.io.read_file(image_path)
image = tf.image.decode_image(image, channels=3)
return image, label
@tf.function
def preprocess_input(image, label):
image = tf.image.resize(image, [224, 224])
image = applications.mobilenet_v2.preprocess_input(image)
return image, label
class F3Classification(BaseModel):
def __init__(self, class_name_list, class_other_first, *args, **kwargs):
......@@ -48,6 +17,34 @@ class F3Classification(BaseModel):
self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
@staticmethod
def history_save(history, save_path):
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, 1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
# plt.show()
plt.savefig(save_path)
@staticmethod
def get_class_label_map(class_name_list, class_other_first=False):
return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}
......@@ -68,21 +65,52 @@ class F3Classification(BaseModel):
label_list.append(tf.one_hot(label, depth=self.class_count))
return image_path_list, label_list
@staticmethod
# @tf.function
def random_rgb_2_bgr(image, label):
if random.random() > 0.2:
return image, label
image = image[:, :, ::-1]
return image, label
@staticmethod
# @tf.function
def random_grayscale_expand(image, label):
if random.random() > 0.1:
return image, label
image = tf.image.rgb_to_grayscale(image)
image = tf.image.grayscale_to_rgb(image)
return image, label
@staticmethod
# @tf.function
def load_image(image_path, label):
image = tf.io.read_file(image_path)
# image = tf.image.decode_image(image, channels=3) # TODO 为什么不行
image = tf.image.decode_png(image, channels=3)
return image, label
@staticmethod
# @tf.function
def preprocess_input(image, label):
image = tf.image.resize(image, [224, 224])
image = applications.mobilenet_v2.preprocess_input(image)
return image, label
def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[]):
image_and_label_list = self.get_image_label_list(dataset_dir)
tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name)
tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
tensor_slice_dataset.map(load_image,
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False)
dataset = tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
dataset = dataset.map(
self.load_image, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
for augmentation_method in augmentation_methods:
tensor_slice_dataset.map(getattr(self, augmentation_method),
dataset = dataset.map(
getattr(self, augmentation_method),
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False)
tensor_slice_dataset.map(preprocess_input,
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=False)
parallel_batch_dataset = tensor_slice_dataset.batch(
dataset = dataset.map(
self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
parallel_batch_dataset = dataset.batch(
batch_size=batch_size,
drop_remainder=True,
num_parallel_calls=tf.data.AUTOTUNE,
......@@ -113,28 +141,29 @@ class F3Classification(BaseModel):
freeze = False
return model
def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
# model = self.load_model()
# model.summary()
#
# model.compile(
# optimizer=optimizers.Adam(learning_rate=3e-4),
# loss=tfa.losses.SigmoidFocalCrossEntropy(),
# metrics=['accuracy', ],
#
# loss_weights=None,
# weighted_metrics=None,
# run_eagerly=None,
# steps_per_execution=None,
# jit_compile=None,
# )
def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
train_dir_name='train', validate_dir_name='test'):
model = self.load_model()
model.summary()
model.compile(
optimizer=optimizers.Adam(learning_rate=3e-4),
loss=tfa.losses.SigmoidFocalCrossEntropy(),
metrics=['accuracy', ],
loss_weights=None,
weighted_metrics=None,
run_eagerly=None,
steps_per_execution=None,
jit_compile=None,
)
train_dataset = self.load_dataset(
dataset_dir=os.path.join(dataset_dir, train_dir_name),
name=train_dir_name,
batch_size=batch_size,
augmentation_methods=[],
# augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
# augmentation_methods=[],
augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
)
validate_dataset = self.load_dataset(
dataset_dir=os.path.join(dataset_dir, validate_dir_name),
......@@ -143,46 +172,17 @@ class F3Classification(BaseModel):
augmentation_methods=[]
)
# ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
#
# history = model.fit(
# train_dataset,
# epochs=epoch,
# validation_data=validate_dataset,
# callbacks=[ckpt_callback, ],
# )
#
# acc = history.history['accuracy']
# val_acc = history.history['val_accuracy']
#
# loss = history.history['loss']
# val_loss = history.history['val_loss']
#
# plt.figure(figsize=(8, 8))
# plt.subplot(2, 1, 1)
# plt.plot(acc, label='Training Accuracy')
# plt.plot(val_acc, label='Validation Accuracy')
# plt.legend(loc='lower right')
# plt.ylabel('Accuracy')
# plt.ylim([min(plt.ylim()), 1])
# plt.title('Training and Validation Accuracy')
#
# plt.subplot(2, 1, 2)
# plt.plot(loss, label='Training Loss')
# plt.plot(val_loss, label='Validation Loss')
# plt.legend(loc='upper right')
# plt.ylabel('Cross Entropy')
# plt.ylim([0, 1.0])
# plt.title('Training and Validation Loss')
# plt.xlabel('epoch')
# plt.show()
ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
history = model.fit(
train_dataset,
epochs=epoch,
validation_data=validate_dataset,
callbacks=[ckpt_callback, ],
)
self.history_save(history, history_save_path)
def test(self):
print(self.class_label_map)
print(self.class_count)
# path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
# label = 5
# image, label = self.load_image(path, label)
# print(image.shape)
# image, label = self.preprocess_input(image, label)
# print(image.shape)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!