37a9d47e by 周伟奇

classification train

1 parent cbeebc6d
@@ -10,6 +10,7 @@ class Retriever:
         self.key_text_set = self.get_key_text_set(target_fields)
 
     def get_key_text_set(self, target_fields):
+        # keyword set
         key_text_set = set()
         for key_text_list in target_fields[self.keys_str].values():
             for key_text, _, _ in key_text_list:
@@ -18,11 +19,13 @@ class Retriever:
 
     @staticmethod
     def key_top1(coordinates_list, key_coordinates):
+        # keyword search direction: topmost
        coordinates_list.sort(key=lambda x: x[1])
        return coordinates_list[0]
 
     @staticmethod
     def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
+        # keyword search direction: right side
         if len(coordinates_list) == 1:
             return coordinates_list[0]
         height = key_coordinates[-1] - key_coordinates[1]
@@ -41,6 +44,7 @@ class Retriever:
 
     @staticmethod
     def value_right(go_res, key_coordinates, top_padding, bottom_padding):
+        # value search direction: right side
         height = key_coordinates[-1] - key_coordinates[1]
         y_min = key_coordinates[1] - (top_padding * height)
         y_max = key_coordinates[-1] + (bottom_padding * height)
@@ -57,6 +61,7 @@ class Retriever:
 
     @staticmethod
     def value_under(go_res, key_coordinates, left_padding, right_padding):
+        # value search direction: below
         width = key_coordinates[2] - key_coordinates[0]
         x_min = key_coordinates[0] - (width * left_padding)
         x_max = key_coordinates[2] + (width * right_padding)
......
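For context, the padded search band these helpers compute can be read as follows. This is a minimal sketch, assuming boxes are `(x_min, y_min, x_max, y_max)` tuples (so `key_coordinates[-1]` is `y_max`); the function name, containment check, and sample numbers are illustrative, not part of the diff:

```python
# One plausible reading of the band filter behind key_right/value_right.
def in_right_band(box, key_box, top_padding, bottom_padding):
    height = key_box[3] - key_box[1]                 # key box height
    y_min = key_box[1] - (top_padding * height)      # band stretched upward
    y_max = key_box[3] + (bottom_padding * height)   # ... and downward
    # candidate must start to the right of the key and sit inside the band
    return box[0] >= key_box[2] and y_min <= box[1] and box[3] <= y_max

print(in_right_band((210, 98, 300, 122), (100, 100, 200, 120), 0.5, 0.5))  # True
```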
@@ -9,7 +9,8 @@ class BaseModel:
         """
         raise NotImplementedError(".load() must be overridden.")
 
-    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
+    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
+              train_dir_name='train', validate_dir_name='test'):
         """
         Model training process
         """
......
@@ -14,10 +14,12 @@ if __name__ == '__main__':
 
     # m.test()
 
-    dataset_dir = '/home/zwq/data/data_224'
+    dataset_dir = '/home/zwq/data/data_224_f3'
     ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
+    history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
     epoch = 100
     batch_size = 128
 
-    m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test')
+    m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
+            train_dir_name='train', validate_dir_name='test')
 
......
@@ -9,37 +9,6 @@ import matplotlib.pyplot as plt
 from base_class import BaseModel
 
 
-@tf.function
-def random_rgb_2_bgr(image, label):
-    if random.random() > 0.5:
-        return image, label
-    image = image[:, :, ::-1]
-    return image, label
-
-
-@tf.function
-def random_grayscale_expand(image, label):
-    if random.random() > 0.1:
-        return image, label
-    image = tf.image.rgb_to_grayscale(image)
-    image = tf.image.grayscale_to_rgb(image)
-    return image, label
-
-
-@tf.function
-def load_image(image_path, label):
-    image = tf.io.read_file(image_path)
-    image = tf.image.decode_image(image, channels=3)
-    return image, label
-
-
-@tf.function
-def preprocess_input(image, label):
-    image = tf.image.resize(image, [224, 224])
-    image = applications.mobilenet_v2.preprocess_input(image)
-    return image, label
-
-
 class F3Classification(BaseModel):
 
     def __init__(self, class_name_list, class_other_first, *args, **kwargs):
@@ -48,6 +17,34 @@ class F3Classification(BaseModel):
         self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
 
     @staticmethod
+    def history_save(history, save_path):
+        acc = history.history['accuracy']
+        val_acc = history.history['val_accuracy']
+
+        loss = history.history['loss']
+        val_loss = history.history['val_loss']
+
+        plt.figure(figsize=(8, 8))
+        plt.subplot(2, 1, 1)
+        plt.plot(acc, label='Training Accuracy')
+        plt.plot(val_acc, label='Validation Accuracy')
+        plt.legend(loc='lower right')
+        plt.ylabel('Accuracy')
+        plt.ylim([min(plt.ylim()), 1])
+        plt.title('Training and Validation Accuracy')
+
+        plt.subplot(2, 1, 2)
+        plt.plot(loss, label='Training Loss')
+        plt.plot(val_loss, label='Validation Loss')
+        plt.legend(loc='upper right')
+        plt.ylabel('Cross Entropy')
+        plt.ylim([0, 1.0])
+        plt.title('Training and Validation Loss')
+        plt.xlabel('epoch')
+        # plt.show()
+        plt.savefig(save_path)
+
+    @staticmethod
     def get_class_label_map(class_name_list, class_other_first=False):
         return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}
 
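A note on `get_class_label_map`: with `class_other_first=True` every index shifts down by one, so the first ("other") class maps to `-1`. Since `tf.one_hot` turns any index outside `[0, depth)` into an all-zero vector, the "other" class is effectively encoded as "none of the known classes". A quick illustration (the class names and depth here are made up):

```python
import tensorflow as tf

class_name_list = ['other', 'id_card', 'bank_card']  # illustrative names
label_map = {cn: idx - 1 for idx, cn in enumerate(class_name_list)}  # class_other_first=True
print(label_map)  # {'other': -1, 'id_card': 0, 'bank_card': 1}

# tf.one_hot maps the out-of-range index -1 to an all-zero vector.
print(tf.one_hot(label_map['other'], depth=2).numpy())    # [0. 0.]
print(tf.one_hot(label_map['id_card'], depth=2).numpy())  # [1. 0.]
```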
@@ -68,21 +65,52 @@ class F3Classification(BaseModel):
                 label_list.append(tf.one_hot(label, depth=self.class_count))
         return image_path_list, label_list
 
+    @staticmethod
+    # @tf.function
+    def random_rgb_2_bgr(image, label):
+        if random.random() > 0.2:
+            return image, label
+        image = image[:, :, ::-1]
+        return image, label
+
+    @staticmethod
+    # @tf.function
+    def random_grayscale_expand(image, label):
+        if random.random() > 0.1:
+            return image, label
+        image = tf.image.rgb_to_grayscale(image)
+        image = tf.image.grayscale_to_rgb(image)
+        return image, label
+
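One caveat with these two augmentations: `dataset.map` traces the function into a graph (with or without the commented-out `@tf.function`), so the Python-level `random.random()` is evaluated once at trace time and the chosen branch stays frozen for the whole dataset. A graph-safe sketch draws the coin flip with TF ops instead, so it is re-drawn per element (same signature, intended as a drop-in replacement for the method body above):

```python
import tensorflow as tf

def random_rgb_2_bgr(image, label):
    # tf.random.uniform is a graph op: re-evaluated per element inside
    # dataset.map rather than fixed once when the function is traced.
    flip = tf.random.uniform([]) < 0.2
    image = tf.cond(flip, lambda: image[:, :, ::-1], lambda: image)
    return image, label
```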
+    @staticmethod
+    # @tf.function
+    def load_image(image_path, label):
+        image = tf.io.read_file(image_path)
+        # image = tf.image.decode_image(image, channels=3)  # TODO why doesn't this work?
+        image = tf.image.decode_png(image, channels=3)
+        return image, label
+
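On the `TODO` above: the likely reason `decode_image` misbehaves is that it can emit a 4-D tensor (animated GIFs) and leaves the static shape unset, which breaks shape-dependent ops such as `tf.image.resize` once the function is traced by `dataset.map`. If the dataset is not PNG-only, a format-agnostic sketch (a suggestion, not part of the diff) would be:

```python
import tensorflow as tf

def load_image(image_path, label):
    data = tf.io.read_file(image_path)
    # expand_animations=False forces a 3-D result, and set_shape restores
    # the rank/channel information that decode_image alone omits.
    image = tf.io.decode_image(data, channels=3, expand_animations=False)
    image.set_shape([None, None, 3])
    return image, label
```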
+    @staticmethod
+    # @tf.function
+    def preprocess_input(image, label):
+        image = tf.image.resize(image, [224, 224])
+        image = applications.mobilenet_v2.preprocess_input(image)
+        return image, label
+
     def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[]):
         image_and_label_list = self.get_image_label_list(dataset_dir)
         tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name)
-        tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
-        tensor_slice_dataset.map(load_image,
-                                 num_parallel_calls=tf.data.AUTOTUNE,
-                                 deterministic=False)
+        dataset = tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
+        dataset = dataset.map(
+            self.load_image, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
         for augmentation_method in augmentation_methods:
-            tensor_slice_dataset.map(getattr(self, augmentation_method),
-                                     num_parallel_calls=tf.data.AUTOTUNE,
-                                     deterministic=False)
-        tensor_slice_dataset.map(preprocess_input,
-                                 num_parallel_calls=tf.data.AUTOTUNE,
-                                 deterministic=False)
-        parallel_batch_dataset = tensor_slice_dataset.batch(
+            dataset = dataset.map(
+                getattr(self, augmentation_method),
+                num_parallel_calls=tf.data.AUTOTUNE,
+                deterministic=False)
+        dataset = dataset.map(
+            self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
+        parallel_batch_dataset = dataset.batch(
             batch_size=batch_size,
             drop_remainder=True,
             num_parallel_calls=tf.data.AUTOTUNE,
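The substantive fix in `load_dataset`: `tf.data` transformations are functional and return a new `Dataset` each time, so the old calls without reassignment were no-ops and the pipeline fed raw, unmapped data downstream. A minimal reproduction:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5)
ds.map(lambda x: x * 2)              # result discarded: ds is unchanged
print(list(ds.as_numpy_iterator()))  # [0, 1, 2, 3, 4]

ds = ds.map(lambda x: x * 2)         # correct: rebind the transformed dataset
print(list(ds.as_numpy_iterator()))  # [0, 2, 4, 6, 8]
```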
@@ -113,28 +141,29 @@ class F3Classification(BaseModel):
             freeze = False
         return model
 
-    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
-        # model = self.load_model()
-        # model.summary()
-        #
-        # model.compile(
-        #     optimizer=optimizers.Adam(learning_rate=3e-4),
-        #     loss=tfa.losses.SigmoidFocalCrossEntropy(),
-        #     metrics=['accuracy', ],
-        #
-        #     loss_weights=None,
-        #     weighted_metrics=None,
-        #     run_eagerly=None,
-        #     steps_per_execution=None,
-        #     jit_compile=None,
-        # )
+    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
+              train_dir_name='train', validate_dir_name='test'):
+        model = self.load_model()
+        model.summary()
+
+        model.compile(
+            optimizer=optimizers.Adam(learning_rate=3e-4),
+            loss=tfa.losses.SigmoidFocalCrossEntropy(),
+            metrics=['accuracy', ],
+
+            loss_weights=None,
+            weighted_metrics=None,
+            run_eagerly=None,
+            steps_per_execution=None,
+            jit_compile=None,
+        )
 
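Side note on the loss choice: `tfa.losses.SigmoidFocalCrossEntropy` (from `tensorflow-addons`) scores each class with an independent sigmoid, so the all-zero targets produced for the "other" class (see `get_class_label_map`) still yield a useful gradient that pushes every score down, whereas a softmax cross-entropy would assign an all-zero target exactly zero loss and no gradient. A small sanity check with made-up numbers:

```python
import tensorflow as tf
import tensorflow_addons as tfa

loss_fn = tfa.losses.SigmoidFocalCrossEntropy()
y_true = tf.constant([[0.0, 0.0, 0.0]])  # "other": none of the known classes
y_pred = tf.constant([[0.1, 0.8, 0.1]])  # probabilities (from_logits=False)
print(loss_fn(y_true, y_pred).numpy())   # finite, non-zero per-sample loss
```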
         train_dataset = self.load_dataset(
             dataset_dir=os.path.join(dataset_dir, train_dir_name),
             name=train_dir_name,
             batch_size=batch_size,
-            augmentation_methods=[],
-            # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
+            # augmentation_methods=[],
+            augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
         )
         validate_dataset = self.load_dataset(
             dataset_dir=os.path.join(dataset_dir, validate_dir_name),
@@ -143,46 +172,17 @@ class F3Classification(BaseModel):
             augmentation_methods=[]
         )
 
-        # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
-        #
-        # history = model.fit(
-        #     train_dataset,
-        #     epochs=epoch,
-        #     validation_data=validate_dataset,
-        #     callbacks=[ckpt_callback, ],
-        # )
-        #
-        # acc = history.history['accuracy']
-        # val_acc = history.history['val_accuracy']
-        #
-        # loss = history.history['loss']
-        # val_loss = history.history['val_loss']
-        #
-        # plt.figure(figsize=(8, 8))
-        # plt.subplot(2, 1, 1)
-        # plt.plot(acc, label='Training Accuracy')
-        # plt.plot(val_acc, label='Validation Accuracy')
-        # plt.legend(loc='lower right')
-        # plt.ylabel('Accuracy')
-        # plt.ylim([min(plt.ylim()), 1])
-        # plt.title('Training and Validation Accuracy')
-        #
-        # plt.subplot(2, 1, 2)
-        # plt.plot(loss, label='Training Loss')
-        # plt.plot(val_loss, label='Validation Loss')
-        # plt.legend(loc='upper right')
-        # plt.ylabel('Cross Entropy')
-        # plt.ylim([0, 1.0])
-        # plt.title('Training and Validation Loss')
-        # plt.xlabel('epoch')
-        # plt.show()
+        ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
+
+        history = model.fit(
+            train_dataset,
+            epochs=epoch,
+            validation_data=validate_dataset,
+            callbacks=[ckpt_callback, ],
+        )
+
+        self.history_save(history, history_save_path)
 
     def test(self):
         print(self.class_label_map)
         print(self.class_count)
-        # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
-        # label = 5
-        # image, label = self.load_image(path, label)
-        # print(image.shape)
-        # image, label = self.preprocess_input(image, label)
-        # print(image.shape)
......