classification train
Showing 4 changed files with 109 additions and 101 deletions
@@ -10,6 +10,7 @@ class Retriever:
         self.key_text_set = self.get_key_text_set(target_fields)
 
     def get_key_text_set(self, target_fields):
+        # Keyword set
         key_text_set = set()
         for key_text_list in target_fields[self.keys_str].values():
             for key_text, _, _ in key_text_list:
@@ -18,11 +19,13 @@ class Retriever:
 
     @staticmethod
     def key_top1(coordinates_list, key_coordinates):
+        # Keyword search direction: topmost
         coordinates_list.sort(key=lambda x: x[1])
         return coordinates_list[0]
 
     @staticmethod
     def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
+        # Keyword search direction: right
         if len(coordinates_list) == 1:
             return coordinates_list[0]
         height = key_coordinates[-1] - key_coordinates[1]
@@ -41,6 +44,7 @@ class Retriever:
 
     @staticmethod
     def value_right(go_res, key_coordinates, top_padding, bottom_padding):
+        # Field value search direction: right
         height = key_coordinates[-1] - key_coordinates[1]
         y_min = key_coordinates[1] - (top_padding * height)
         y_max = key_coordinates[-1] + (bottom_padding * height)
@@ -57,6 +61,7 @@ class Retriever:
 
     @staticmethod
     def value_under(go_res, key_coordinates, left_padding, right_padding):
+        # Field value search direction: below
         width = key_coordinates[2] - key_coordinates[0]
         x_min = key_coordinates[0] - (width * left_padding)
         x_max = key_coordinates[2] + (width * right_padding)
...
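Note: the comments added in this file all document one geometric convention: each lookup derives a search band from the key's bounding box, scaled by padding factors. A minimal sketch of that idea in isolation (a hypothetical helper; boxes are assumed to be (x_min, y_min, x_max, y_max) tuples, and the actual candidate filtering sits outside the hunks shown):

    def boxes_right_of_key(boxes, key, top_padding, bottom_padding):
        # Vertical band scaled by the key box height, as in value_right above.
        height = key[3] - key[1]
        y_min = key[1] - top_padding * height
        y_max = key[3] + bottom_padding * height
        # Keep boxes inside the band that start to the right of the key.
        return [b for b in boxes if b[1] >= y_min and b[3] <= y_max and b[0] >= key[2]]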
@@ -9,7 +9,8 @@ class BaseModel:
         """
         raise NotImplementedError(".load() must be overridden.")
 
-    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
+    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
+              train_dir_name='train', validate_dir_name='test'):
         """
         Model training process
         """
...
@@ -14,10 +14,12 @@ if __name__ == '__main__':
 
     # m.test()
 
-    dataset_dir = '/home/zwq/data/data_224'
+    dataset_dir = '/home/zwq/data/data_224_f3'
     ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
+    history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
     epoch = 100
     batch_size = 128
 
-    m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test')
+    m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
+            train_dir_name='train', validate_dir_name='test')
 
...
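Aside: both artifact paths embed the launch timestamp, so successive runs never overwrite each other's checkpoint or history plot. Note that '%H:%M:%S' puts colons into the filename, which works on Linux but would be rejected on Windows; for example:

    from datetime import datetime
    ts = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    print('ckpt_{0}.h5'.format(ts))  # e.g. ckpt_2022-06-01_12:30:05.h5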
@@ -9,37 +9,6 @@ import matplotlib.pyplot as plt
 from base_class import BaseModel
 
 
-@tf.function
-def random_rgb_2_bgr(image, label):
-    if random.random() > 0.5:
-        return image, label
-    image = image[:, :, ::-1]
-    return image, label
-
-
-@tf.function
-def random_grayscale_expand(image, label):
-    if random.random() > 0.1:
-        return image, label
-    image = tf.image.rgb_to_grayscale(image)
-    image = tf.image.grayscale_to_rgb(image)
-    return image, label
-
-
-@tf.function
-def load_image(image_path, label):
-    image = tf.io.read_file(image_path)
-    image = tf.image.decode_image(image, channels=3)
-    return image, label
-
-
-@tf.function
-def preprocess_input(image, label):
-    image = tf.image.resize(image, [224, 224])
-    image = applications.mobilenet_v2.preprocess_input(image)
-    return image, label
-
-
 class F3Classification(BaseModel):
 
     def __init__(self, class_name_list, class_other_first, *args, **kwargs):
@@ -48,6 +17,34 @@ class F3Classification(BaseModel):
         self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
 
     @staticmethod
+    def history_save(history, save_path):
+        acc = history.history['accuracy']
+        val_acc = history.history['val_accuracy']
+
+        loss = history.history['loss']
+        val_loss = history.history['val_loss']
+
+        plt.figure(figsize=(8, 8))
+        plt.subplot(2, 1, 1)
+        plt.plot(acc, label='Training Accuracy')
+        plt.plot(val_acc, label='Validation Accuracy')
+        plt.legend(loc='lower right')
+        plt.ylabel('Accuracy')
+        plt.ylim([min(plt.ylim()), 1])
+        plt.title('Training and Validation Accuracy')
+
+        plt.subplot(2, 1, 2)
+        plt.plot(loss, label='Training Loss')
+        plt.plot(val_loss, label='Validation Loss')
+        plt.legend(loc='upper right')
+        plt.ylabel('Cross Entropy')
+        plt.ylim([0, 1.0])
+        plt.title('Training and Validation Loss')
+        plt.xlabel('epoch')
+        # plt.show()
+        plt.savefig(save_path)
+
+    @staticmethod
     def get_class_label_map(class_name_list, class_other_first=False):
         return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}
 
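For context on the surrounding code: with class_other_first=True, get_class_label_map assigns the first class the label -1, and tf.one_hot(-1, depth) later yields an all-zero vector (out-of-range indices take the off_value everywhere), so the "other" class is effectively encoded as "no positive class". A small illustration with hypothetical class names:

    names = ['other', 'id_card', 'bank_card']
    print({n: i - 1 for i, n in enumerate(names)})  # {'other': -1, 'id_card': 0, 'bank_card': 1}
    # tf.one_hot(-1, depth=2).numpy() -> array([0., 0.], dtype=float32)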
@@ -68,21 +65,52 @@ class F3Classification(BaseModel):
         label_list.append(tf.one_hot(label, depth=self.class_count))
         return image_path_list, label_list
 
+    @staticmethod
+    # @tf.function
+    def random_rgb_2_bgr(image, label):
+        if random.random() > 0.2:
+            return image, label
+        image = image[:, :, ::-1]
+        return image, label
+
+    @staticmethod
+    # @tf.function
+    def random_grayscale_expand(image, label):
+        if random.random() > 0.1:
+            return image, label
+        image = tf.image.rgb_to_grayscale(image)
+        image = tf.image.grayscale_to_rgb(image)
+        return image, label
+
+    @staticmethod
+    # @tf.function
+    def load_image(image_path, label):
+        image = tf.io.read_file(image_path)
+        # image = tf.image.decode_image(image, channels=3)  # TODO: why doesn't this work?
+        image = tf.image.decode_png(image, channels=3)
+        return image, label
+
+    @staticmethod
+    # @tf.function
+    def preprocess_input(image, label):
+        image = tf.image.resize(image, [224, 224])
+        image = applications.mobilenet_v2.preprocess_input(image)
+        return image, label
+
     def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[]):
         image_and_label_list = self.get_image_label_list(dataset_dir)
         tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name)
-        tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
-        tensor_slice_dataset.map(load_image,
-                                 num_parallel_calls=tf.data.AUTOTUNE,
-                                 deterministic=False)
+        dataset = tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
+        dataset = dataset.map(
+            self.load_image, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
         for augmentation_method in augmentation_methods:
-            tensor_slice_dataset.map(getattr(self, augmentation_method),
-                                     num_parallel_calls=tf.data.AUTOTUNE,
-                                     deterministic=False)
-        tensor_slice_dataset.map(preprocess_input,
-                                 num_parallel_calls=tf.data.AUTOTUNE,
-                                 deterministic=False)
-        parallel_batch_dataset = tensor_slice_dataset.batch(
+            dataset = dataset.map(
+                getattr(self, augmentation_method),
+                num_parallel_calls=tf.data.AUTOTUNE,
+                deterministic=False)
+        dataset = dataset.map(
+            self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
+        parallel_batch_dataset = dataset.batch(
            batch_size=batch_size,
             drop_remainder=True,
             num_parallel_calls=tf.data.AUTOTUNE,
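The substantive fix in load_dataset: tf.data.Dataset.shuffle and map return new datasets rather than mutating in place, so the old code, which discarded their return values and batched the raw path/label dataset, never actually decoded, augmented, or preprocessed the images. A minimal standalone sketch of the pattern:

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    ds.map(lambda x: x + 1)              # no effect: the transformed dataset is discarded
    ds = ds.map(lambda x: x + 1)         # correct: rebind the name to the returned dataset
    print(list(ds.as_numpy_iterator()))  # [2, 3, 4]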
@@ -113,28 +141,29 @@ class F3Classification(BaseModel):
             freeze = False
         return model
 
-    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
-        # model = self.load_model()
-        # model.summary()
-        #
-        # model.compile(
-        #     optimizer=optimizers.Adam(learning_rate=3e-4),
-        #     loss=tfa.losses.SigmoidFocalCrossEntropy(),
-        #     metrics=['accuracy', ],
-        #
-        #     loss_weights=None,
-        #     weighted_metrics=None,
-        #     run_eagerly=None,
-        #     steps_per_execution=None,
-        #     jit_compile=None,
-        # )
+    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
+              train_dir_name='train', validate_dir_name='test'):
+        model = self.load_model()
+        model.summary()
+
+        model.compile(
+            optimizer=optimizers.Adam(learning_rate=3e-4),
+            loss=tfa.losses.SigmoidFocalCrossEntropy(),
+            metrics=['accuracy', ],
+
+            loss_weights=None,
+            weighted_metrics=None,
+            run_eagerly=None,
+            steps_per_execution=None,
+            jit_compile=None,
+        )
 
         train_dataset = self.load_dataset(
             dataset_dir=os.path.join(dataset_dir, train_dir_name),
             name=train_dir_name,
             batch_size=batch_size,
-            augmentation_methods=[],
-            # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
+            # augmentation_methods=[],
+            augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
         )
         validate_dataset = self.load_dataset(
             dataset_dir=os.path.join(dataset_dir, validate_dir_name),
@@ -143,46 +172,17 @@ class F3Classification(BaseModel):
             augmentation_methods=[]
         )
 
-        # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
-        #
-        # history = model.fit(
-        #     train_dataset,
-        #     epochs=epoch,
-        #     validation_data=validate_dataset,
-        #     callbacks=[ckpt_callback, ],
-        # )
-        #
-        # acc = history.history['accuracy']
-        # val_acc = history.history['val_accuracy']
-        #
-        # loss = history.history['loss']
-        # val_loss = history.history['val_loss']
-        #
-        # plt.figure(figsize=(8, 8))
-        # plt.subplot(2, 1, 1)
-        # plt.plot(acc, label='Training Accuracy')
-        # plt.plot(val_acc, label='Validation Accuracy')
-        # plt.legend(loc='lower right')
-        # plt.ylabel('Accuracy')
-        # plt.ylim([min(plt.ylim()), 1])
-        # plt.title('Training and Validation Accuracy')
-        #
-        # plt.subplot(2, 1, 2)
-        # plt.plot(loss, label='Training Loss')
-        # plt.plot(val_loss, label='Validation Loss')
-        # plt.legend(loc='upper right')
-        # plt.ylabel('Cross Entropy')
-        # plt.ylim([0, 1.0])
-        # plt.title('Training and Validation Loss')
-        # plt.xlabel('epoch')
-        # plt.show()
+        ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
+
+        history = model.fit(
+            train_dataset,
+            epochs=epoch,
+            validation_data=validate_dataset,
+            callbacks=[ckpt_callback, ],
+        )
+
+        self.history_save(history, history_save_path)
 
     def test(self):
         print(self.class_label_map)
         print(self.class_count)
-        # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
-        # label = 5
-        # image, label = self.load_image(path, label)
-        # print(image.shape)
-        # image, label = self.preprocess_input(image, label)
-        # print(image.shape)
...
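On the remaining "TODO: why doesn't this work?" beside tf.image.decode_image: a likely explanation (an assumption, not confirmed by this diff) is that decode_image also handles animated GIFs and therefore returns a tensor of statically unknown rank, which breaks rank-dependent ops such as tf.image.resize once the function is traced inside Dataset.map; decode_png always returns a rank-3 tensor. If the dataset may contain formats other than PNG, a sketch of the usual workaround inside load_image:

    image = tf.io.read_file(image_path)
    # expand_animations=False forces a single rank-3 frame even for GIFs.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    image.set_shape([None, None, 3])  # pin the static shape for graph mode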