classification train

Showing 4 changed files with 107 additions and 99 deletions
| ... | @@ -10,6 +10,7 @@ class Retriever: | ... | @@ -10,6 +10,7 @@ class Retriever: |
| 10 | self.key_text_set = self.get_key_text_set(target_fields) | 10 | self.key_text_set = self.get_key_text_set(target_fields) |
| 11 | 11 | ||
| 12 | def get_key_text_set(self, target_fields): | 12 | def get_key_text_set(self, target_fields): |
| 13 | # keyword set | ||
| 13 | key_text_set = set() | 14 | key_text_set = set() |
| 14 | for key_text_list in target_fields[self.keys_str].values(): | 15 | for key_text_list in target_fields[self.keys_str].values(): |
| 15 | for key_text, _, _ in key_text_list: | 16 | for key_text, _, _ in key_text_list: |
| ... | @@ -18,11 +19,13 @@ class Retriever: | ... | @@ -18,11 +19,13 @@ class Retriever: |
| 18 | 19 | ||
| 19 | @staticmethod | 20 | @staticmethod |
| 20 | def key_top1(coordinates_list, key_coordinates): | 21 | def key_top1(coordinates_list, key_coordinates): |
| 22 | # keyword search direction: topmost | ||
| 21 | coordinates_list.sort(key=lambda x: x[1]) | 23 | coordinates_list.sort(key=lambda x: x[1]) |
| 22 | return coordinates_list[0] | 24 | return coordinates_list[0] |
| 23 | 25 | ||
| 24 | @staticmethod | 26 | @staticmethod |
| 25 | def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding): | 27 | def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding): |
| 28 | # keyword search direction: right side | ||
| 26 | if len(coordinates_list) == 1: | 29 | if len(coordinates_list) == 1: |
| 27 | return coordinates_list[0] | 30 | return coordinates_list[0] |
| 28 | height = key_coordinates[-1] - key_coordinates[1] | 31 | height = key_coordinates[-1] - key_coordinates[1] |
| ... | @@ -41,6 +44,7 @@ class Retriever: | ... | @@ -41,6 +44,7 @@ class Retriever: |
| 41 | 44 | ||
| 42 | @staticmethod | 45 | @staticmethod |
| 43 | def value_right(go_res, key_coordinates, top_padding, bottom_padding): | 46 | def value_right(go_res, key_coordinates, top_padding, bottom_padding): |
| 47 | # field value search direction: right side | ||
| 44 | height = key_coordinates[-1] - key_coordinates[1] | 48 | height = key_coordinates[-1] - key_coordinates[1] |
| 45 | y_min = key_coordinates[1] - (top_padding * height) | 49 | y_min = key_coordinates[1] - (top_padding * height) |
| 46 | y_max = key_coordinates[-1] + (bottom_padding * height) | 50 | y_max = key_coordinates[-1] + (bottom_padding * height) |
| ... | @@ -57,6 +61,7 @@ class Retriever: | ... | @@ -57,6 +61,7 @@ class Retriever: |
| 57 | 61 | ||
| 58 | @staticmethod | 62 | @staticmethod |
| 59 | def value_under(go_res, key_coordinates, left_padding, right_padding): | 63 | def value_under(go_res, key_coordinates, left_padding, right_padding): |
| 64 | # field value search direction: below | ||
| 60 | width = key_coordinates[2] - key_coordinates[0] | 65 | width = key_coordinates[2] - key_coordinates[0] |
| 61 | x_min = key_coordinates[0] - (width * left_padding) | 66 | x_min = key_coordinates[0] - (width * left_padding) |
| 62 | x_max = key_coordinates[2] + (width * right_padding) | 67 | x_max = key_coordinates[2] + (width * right_padding) | ... | ... |
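The Retriever changes above are comment-only: each lookup helper now states the direction it searches (topmost, right of the key, below the key). The index arithmetic (`key_coordinates[-1] - key_coordinates[1]` for height, `[2] - [0]` for width) implies boxes are `[x_min, y_min, x_max, y_max]`. A minimal sketch of the vertical band that `value_right` scans under that assumption; the function name and standalone form are illustrative, not from the repo:

```python
def value_right_band(key_box, top_padding, bottom_padding):
    # Assumes key_box = [x_min, y_min, x_max, y_max].
    height = key_box[3] - key_box[1]
    y_min = key_box[1] - top_padding * height     # widen the band upward by a fraction of key height
    y_max = key_box[3] + bottom_padding * height  # and downward likewise
    return y_min, y_max
```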
| ... | @@ -9,7 +9,8 @@ class BaseModel: | ... | @@ -9,7 +9,8 @@ class BaseModel: |
| 9 | """ | 9 | """ |
| 10 | raise NotImplementedError(".load() must be overridden.") | 10 | raise NotImplementedError(".load() must be overridden.") |
| 11 | 11 | ||
| 12 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'): | 12 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, |
| 13 | train_dir_name='train', validate_dir_name='test'): | ||
| 13 | """ | 14 | """ |
| 14 | Model training process | 15 | Model training process |
| 15 | """ | 16 | """ | ... | ... |
| ... | @@ -14,10 +14,12 @@ if __name__ == '__main__': | ... | @@ -14,10 +14,12 @@ if __name__ == '__main__': |
| 14 | 14 | ||
| 15 | # m.test() | 15 | # m.test() |
| 16 | 16 | ||
| 17 | dataset_dir = '/home/zwq/data/data_224' | 17 | dataset_dir = '/home/zwq/data/data_224_f3' |
| 18 | ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))) | 18 | ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))) |
| 19 | history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))) | ||
| 19 | epoch = 100 | 20 | epoch = 100 |
| 20 | batch_size = 128 | 21 | batch_size = 128 |
| 21 | 22 | ||
| 22 | m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test') | 23 | m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path, |
| 24 | train_dir_name='train', validate_dir_name='test') | ||
| 23 | 25 | ... | ... |
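A small nit in the runner: `datetime.now()` is called once per path, so a checkpoint and its history plot can end up with different timestamps if the two calls straddle a second boundary. A sketch that derives both from one timestamp (the `ts` variable is illustrative):

```python
import os
from datetime import datetime

ts = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(ts))
history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(ts))
```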
| ... | @@ -9,37 +9,6 @@ import matplotlib.pyplot as plt | ... | @@ -9,37 +9,6 @@ import matplotlib.pyplot as plt |
| 9 | from base_class import BaseModel | 9 | from base_class import BaseModel |
| 10 | 10 | ||
| 11 | 11 | ||
| 12 | @tf.function | ||
| 13 | def random_rgb_2_bgr(image, label): | ||
| 14 | if random.random() > 0.5: | ||
| 15 | return image, label | ||
| 16 | image = image[:, :, ::-1] | ||
| 17 | return image, label | ||
| 18 | |||
| 19 | |||
| 20 | @tf.function | ||
| 21 | def random_grayscale_expand(image, label): | ||
| 22 | if random.random() > 0.1: | ||
| 23 | return image, label | ||
| 24 | image = tf.image.rgb_to_grayscale(image) | ||
| 25 | image = tf.image.grayscale_to_rgb(image) | ||
| 26 | return image, label | ||
| 27 | |||
| 28 | |||
| 29 | @tf.function | ||
| 30 | def load_image(image_path, label): | ||
| 31 | image = tf.io.read_file(image_path) | ||
| 32 | image = tf.image.decode_image(image, channels=3) | ||
| 33 | return image, label | ||
| 34 | |||
| 35 | |||
| 36 | @tf.function | ||
| 37 | def preprocess_input(image, label): | ||
| 38 | image = tf.image.resize(image, [224, 224]) | ||
| 39 | image = applications.mobilenet_v2.preprocess_input(image) | ||
| 40 | return image, label | ||
| 41 | |||
| 42 | |||
| 43 | class F3Classification(BaseModel): | 12 | class F3Classification(BaseModel): |
| 44 | 13 | ||
| 45 | def __init__(self, class_name_list, class_other_first, *args, **kwargs): | 14 | def __init__(self, class_name_list, class_other_first, *args, **kwargs): |
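The four module-level `@tf.function` helpers removed above are not gone; they reappear below as static methods with the decorator commented out. Worth flagging either way: `random.random()` is plain Python, so when `Dataset.map` traces the function into a graph, the comparison runs once at trace time and the "random" branch becomes a constant for the whole run. A graph-native sketch of the BGR swap, with the 0.2 probability taken from the new code (this structure is a suggested alternative, not what the commit does):

```python
import tensorflow as tf

def random_rgb_2_bgr(image, label):
    # tf.random.uniform executes per element inside the traced graph,
    # unlike random.random(), which is frozen at trace time.
    swap = tf.random.uniform(()) < 0.2
    image = tf.cond(swap, lambda: image[:, :, ::-1], lambda: image)
    return image, label
```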
| ... | @@ -48,6 +17,34 @@ class F3Classification(BaseModel): | ... | @@ -48,6 +17,34 @@ class F3Classification(BaseModel): |
| 48 | self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) | 17 | self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) |
| 49 | 18 | ||
| 50 | @staticmethod | 19 | @staticmethod |
| 20 | def history_save(history, save_path): | ||
| 21 | acc = history.history['accuracy'] | ||
| 22 | val_acc = history.history['val_accuracy'] | ||
| 23 | |||
| 24 | loss = history.history['loss'] | ||
| 25 | val_loss = history.history['val_loss'] | ||
| 26 | |||
| 27 | plt.figure(figsize=(8, 8)) | ||
| 28 | plt.subplot(2, 1, 1) | ||
| 29 | plt.plot(acc, label='Training Accuracy') | ||
| 30 | plt.plot(val_acc, label='Validation Accuracy') | ||
| 31 | plt.legend(loc='lower right') | ||
| 32 | plt.ylabel('Accuracy') | ||
| 33 | plt.ylim([min(plt.ylim()), 1]) | ||
| 34 | plt.title('Training and Validation Accuracy') | ||
| 35 | |||
| 36 | plt.subplot(2, 1, 2) | ||
| 37 | plt.plot(loss, label='Training Loss') | ||
| 38 | plt.plot(val_loss, label='Validation Loss') | ||
| 39 | plt.legend(loc='upper right') | ||
| 40 | plt.ylabel('Cross Entropy') | ||
| 41 | plt.ylim([0, 1.0]) | ||
| 42 | plt.title('Training and Validation Loss') | ||
| 43 | plt.xlabel('epoch') | ||
| 44 | # plt.show() | ||
| 45 | plt.savefig(save_path) | ||
| 46 | |||
| 47 | @staticmethod | ||
| 51 | def get_class_label_map(class_name_list, class_other_first=False): | 48 | def get_class_label_map(class_name_list, class_other_first=False): |
| 52 | return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)} | 49 | return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)} |
| 53 | 50 | ||
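`history_save` is the plotting block that used to sit commented out at the end of `train`, lifted into a reusable static method that writes the figure instead of showing it. One caveat if `train` is ever called more than once per process: `pyplot` keeps figures alive until they are closed. A hedged one-line addition after the save:

```python
import matplotlib.pyplot as plt

# ... build the accuracy/loss figure as in history_save, then:
plt.savefig(save_path)
plt.close()  # release the figure; repeated training runs otherwise accumulate open figures
```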
| ... | @@ -68,21 +65,52 @@ class F3Classification(BaseModel): | ... | @@ -68,21 +65,52 @@ class F3Classification(BaseModel): |
| 68 | label_list.append(tf.one_hot(label, depth=self.class_count)) | 65 | label_list.append(tf.one_hot(label, depth=self.class_count)) |
| 69 | return image_path_list, label_list | 66 | return image_path_list, label_list |
| 70 | 67 | ||
| 68 | @staticmethod | ||
| 69 | # @tf.function | ||
| 70 | def random_rgb_2_bgr(image, label): | ||
| 71 | if random.random() > 0.2: | ||
| 72 | return image, label | ||
| 73 | image = image[:, :, ::-1] | ||
| 74 | return image, label | ||
| 75 | |||
| 76 | @staticmethod | ||
| 77 | # @tf.function | ||
| 78 | def random_grayscale_expand(image, label): | ||
| 79 | if random.random() > 0.1: | ||
| 80 | return image, label | ||
| 81 | image = tf.image.rgb_to_grayscale(image) | ||
| 82 | image = tf.image.grayscale_to_rgb(image) | ||
| 83 | return image, label | ||
| 84 | |||
| 85 | @staticmethod | ||
| 86 | # @tf.function | ||
| 87 | def load_image(image_path, label): | ||
| 88 | image = tf.io.read_file(image_path) | ||
| 89 | # image = tf.image.decode_image(image, channels=3) # TODO: why doesn't this work? | ||
| 90 | image = tf.image.decode_png(image, channels=3) | ||
| 91 | return image, label | ||
| 92 | |||
| 93 | @staticmethod | ||
| 94 | # @tf.function | ||
| 95 | def preprocess_input(image, label): | ||
| 96 | image = tf.image.resize(image, [224, 224]) | ||
| 97 | image = applications.mobilenet_v2.preprocess_input(image) | ||
| 98 | return image, label | ||
| 99 | |||
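On the `# TODO: why doesn't this work?` next to `decode_image`: `tf.image.decode_image` also handles animated GIFs, so its result does not have a statically known rank (3-D or 4-D), and shape-dependent ops like `tf.image.resize` then fail when the function is traced by `Dataset.map`. `decode_png` always yields a rank-3 tensor, which is why the swap works, at the price of assuming PNG input. A sketch of the documented alternative that keeps `decode_image`:

```python
import tensorflow as tf

def load_image(image_path, label):
    data = tf.io.read_file(image_path)
    # expand_animations=False forces a 3-D result (first frame of a GIF),
    # so tf.image.resize sees a statically known rank.
    image = tf.io.decode_image(data, channels=3, expand_animations=False)
    image.set_shape([None, None, 3])  # make the shape explicit for graph-mode inference
    return image, label
```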
| 71 | def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[]): | 100 | def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[]): |
| 72 | image_and_label_list = self.get_image_label_list(dataset_dir) | 101 | image_and_label_list = self.get_image_label_list(dataset_dir) |
| 73 | tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name) | 102 | tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name) |
| 74 | tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True) | 103 | dataset = tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True) |
| 75 | tensor_slice_dataset.map(load_image, | 104 | dataset = dataset.map( |
| 76 | num_parallel_calls=tf.data.AUTOTUNE, | 105 | self.load_image, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False) |
| 77 | deterministic=False) | ||
| 78 | for augmentation_method in augmentation_methods: | 106 | for augmentation_method in augmentation_methods: |
| 79 | tensor_slice_dataset.map(getattr(self, augmentation_method), | 107 | dataset = dataset.map( |
| 108 | getattr(self, augmentation_method), | ||
| 80 | num_parallel_calls=tf.data.AUTOTUNE, | 109 | num_parallel_calls=tf.data.AUTOTUNE, |
| 81 | deterministic=False) | 110 | deterministic=False) |
| 82 | tensor_slice_dataset.map(preprocess_input, | 111 | dataset = dataset.map( |
| 83 | num_parallel_calls=tf.data.AUTOTUNE, | 112 | self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False) |
| 84 | deterministic=False) | 113 | parallel_batch_dataset = dataset.batch( |
| 85 | parallel_batch_dataset = tensor_slice_dataset.batch( | ||
| 86 | batch_size=batch_size, | 114 | batch_size=batch_size, |
| 87 | drop_remainder=True, | 115 | drop_remainder=True, |
| 88 | num_parallel_calls=tf.data.AUTOTUNE, | 116 | num_parallel_calls=tf.data.AUTOTUNE, |
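The `load_dataset` rewrite is the substantive fix in this commit: `tf.data.Dataset.shuffle` and `map` are not in-place, each returns a new dataset, and the old code discarded every result, so the batched dataset still contained raw file-path strings. The new code reassigns at each step. A minimal demonstration of the trap:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(4)
ds.map(lambda x: x * 10)             # result discarded; ds is unchanged
print(list(ds.as_numpy_iterator()))  # [0, 1, 2, 3]

ds = ds.map(lambda x: x * 10)        # reassign, as the fixed code does
print(list(ds.as_numpy_iterator()))  # [0, 10, 20, 30]
```

(The mutable default `augmentation_methods=[]` survives the rewrite; it is harmless here because the list is never mutated, but a `None` default is the usual idiom.)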
| ... | @@ -113,28 +141,29 @@ class F3Classification(BaseModel): | ... | @@ -113,28 +141,29 @@ class F3Classification(BaseModel): |
| 113 | freeze = False | 141 | freeze = False |
| 114 | return model | 142 | return model |
| 115 | 143 | ||
| 116 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'): | 144 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, |
| 117 | # model = self.load_model() | 145 | train_dir_name='train', validate_dir_name='test'): |
| 118 | # model.summary() | 146 | model = self.load_model() |
| 119 | # | 147 | model.summary() |
| 120 | # model.compile( | 148 | |
| 121 | # optimizer=optimizers.Adam(learning_rate=3e-4), | 149 | model.compile( |
| 122 | # loss=tfa.losses.SigmoidFocalCrossEntropy(), | 150 | optimizer=optimizers.Adam(learning_rate=3e-4), |
| 123 | # metrics=['accuracy', ], | 151 | loss=tfa.losses.SigmoidFocalCrossEntropy(), |
| 124 | # | 152 | metrics=['accuracy', ], |
| 125 | # loss_weights=None, | 153 | |
| 126 | # weighted_metrics=None, | 154 | loss_weights=None, |
| 127 | # run_eagerly=None, | 155 | weighted_metrics=None, |
| 128 | # steps_per_execution=None, | 156 | run_eagerly=None, |
| 129 | # jit_compile=None, | 157 | steps_per_execution=None, |
| 130 | # ) | 158 | jit_compile=None, |
| 159 | ) | ||
| 131 | 160 | ||
| 132 | train_dataset = self.load_dataset( | 161 | train_dataset = self.load_dataset( |
| 133 | dataset_dir=os.path.join(dataset_dir, train_dir_name), | 162 | dataset_dir=os.path.join(dataset_dir, train_dir_name), |
| 134 | name=train_dir_name, | 163 | name=train_dir_name, |
| 135 | batch_size=batch_size, | 164 | batch_size=batch_size, |
| 136 | augmentation_methods=[], | 165 | # augmentation_methods=[], |
| 137 | # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'], | 166 | augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'], |
| 138 | ) | 167 | ) |
| 139 | validate_dataset = self.load_dataset( | 168 | validate_dataset = self.load_dataset( |
| 140 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), | 169 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), |
| ... | @@ -143,46 +172,17 @@ class F3Classification(BaseModel): | ... | @@ -143,46 +172,17 @@ class F3Classification(BaseModel): |
| 143 | augmentation_methods=[] | 172 | augmentation_methods=[] |
| 144 | ) | 173 | ) |
| 145 | 174 | ||
| 146 | # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) | 175 | ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) |
| 147 | # | 176 | |
| 148 | # history = model.fit( | 177 | history = model.fit( |
| 149 | # train_dataset, | 178 | train_dataset, |
| 150 | # epochs=epoch, | 179 | epochs=epoch, |
| 151 | # validation_data=validate_dataset, | 180 | validation_data=validate_dataset, |
| 152 | # callbacks=[ckpt_callback, ], | 181 | callbacks=[ckpt_callback, ], |
| 153 | # ) | 182 | ) |
| 154 | # | 183 | |
| 155 | # acc = history.history['accuracy'] | 184 | self.history_save(history, history_save_path) |
| 156 | # val_acc = history.history['val_accuracy'] | ||
| 157 | # | ||
| 158 | # loss = history.history['loss'] | ||
| 159 | # val_loss = history.history['val_loss'] | ||
| 160 | # | ||
| 161 | # plt.figure(figsize=(8, 8)) | ||
| 162 | # plt.subplot(2, 1, 1) | ||
| 163 | # plt.plot(acc, label='Training Accuracy') | ||
| 164 | # plt.plot(val_acc, label='Validation Accuracy') | ||
| 165 | # plt.legend(loc='lower right') | ||
| 166 | # plt.ylabel('Accuracy') | ||
| 167 | # plt.ylim([min(plt.ylim()), 1]) | ||
| 168 | # plt.title('Training and Validation Accuracy') | ||
| 169 | # | ||
| 170 | # plt.subplot(2, 1, 2) | ||
| 171 | # plt.plot(loss, label='Training Loss') | ||
| 172 | # plt.plot(val_loss, label='Validation Loss') | ||
| 173 | # plt.legend(loc='upper right') | ||
| 174 | # plt.ylabel('Cross Entropy') | ||
| 175 | # plt.ylim([0, 1.0]) | ||
| 176 | # plt.title('Training and Validation Loss') | ||
| 177 | # plt.xlabel('epoch') | ||
| 178 | # plt.show() | ||
| 179 | 185 | ||
| 180 | def test(self): | 186 | def test(self): |
| 181 | print(self.class_label_map) | 187 | print(self.class_label_map) |
| 182 | print(self.class_count) | 188 | print(self.class_count) |
| 183 | # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg' | ||
| 184 | # label = 5 | ||
| 185 | # image, label = self.load_image(path, label) | ||
| 186 | # print(image.shape) | ||
| 187 | # image, label = self.preprocess_input(image, label) | ||
| 188 | # print(image.shape) | ... | ... |
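With the training path now uncommented and live, one default worth knowing: `callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)` monitors `val_loss`, so "best" means lowest validation loss, not highest accuracy. A sketch making that choice explicit (switching the metric is an option, not what the commit does):

```python
from tensorflow.keras import callbacks

ckpt_callback = callbacks.ModelCheckpoint(
    ckpt_path,
    monitor='val_accuracy',  # default is 'val_loss'
    mode='max',              # val_accuracy improves upward
    save_best_only=True,
)
```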