add async test

Showing 9 changed files with 205 additions and 29 deletions
async_test/classification/__init__.py  0 → 100644
async_test/classification/const.py  0 → 100644
async_test/classification/model.py  0 → 100644
import tensorflow as tf


class F3Classification:

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        self.class_name_list = class_name_list
        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
        self.model_name = 'classification_model'
        self.signature_name = 'serving_default'
        self.server_name = 'server_1'

    @staticmethod
    def preprocess_input(image):
        image = tf.image.resize(image, [224, 224])
        image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
        input_images = tf.expand_dims(image, axis=0)  # add the batch dimension expected by TF Serving
        return input_images

    def reprocess_output(self, outputs, thresholds=0.5):
        # Only the first element of the batch is used (each request carries a single image).
        for output in outputs:
            idx = tf.math.argmax(output)
            confidence = output[idx]
            if confidence < thresholds:
                idx = -1  # below the threshold: map to the first entry of class_name_list (the 'other' class)
            label = self.class_name_list[idx + 1]
            break

        res = {
            'label': label,
            'confidence': float(confidence),  # cast to a native float so the response is JSON-serializable
        }
        return res
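For reference, a minimal offline sketch of how this helper is used (the class names and scores below are made up for illustration; the actual TF Serving wiring is in server2.py further down):

import numpy as np

from classification.model import F3Classification

# 'other' first, then the real classes; the exact names here are placeholders.
classifier = F3Classification(class_name_list=['other', 'cat', 'dog'], class_other_first=True)

# Pretend this is the batch of sigmoid scores returned by TF Serving for one image.
print(classifier.reprocess_output(np.array([[0.12, 0.91]], dtype=np.float32)))
# -> {'label': 'dog', 'confidence': ~0.91}
print(classifier.reprocess_output(np.array([[0.20, 0.30]], dtype=np.float32)))
# -> below-threshold scores fall back to {'label': 'other', ...}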
async_test/locustfile.py  0 → 100644
import time
from locust import HttpUser, task, between, constant, tag


class QuickstartUser(HttpUser):
    # wait_time = between(1, 5)

    @tag('sync')
    @task
    def sync_test(self):
        self.client.get("/sync")

    @tag('async')
    @task
    def async_test(self):
        self.client.get("/async")
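One possible way to drive this file against server1.py below; host, port, user count and duration are placeholders, and the flag names are those of recent Locust releases:

# locust -f async_test/locustfile.py --host http://127.0.0.1:1337 \
#        --headless -u 50 -r 10 -t 1m --tags async
# Swap --tags async for --tags sync to hit the blocking endpoint instead.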
async_test/requirements.txt  0 → 100644
async_test/server1.py  0 → 100644
import asyncio
import time

from sanic import Sanic
from sanic.response import json

app = Sanic("async_test")


@app.get("/sync")
async def sync_handler(request):
    time.sleep(2)
    return json({'code': 1})


@app.get("/async")
async def async_handler(request):
    await asyncio.sleep(2)
    return json({'code': 1})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=1337, workers=5)
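The intent of the two handlers: time.sleep(2) blocks the whole Sanic worker, while await asyncio.sleep(2) yields the event loop to other requests, which is the difference the Locust test is meant to expose. If a blocking call were unavoidable, one common workaround is to push it onto a thread pool; a minimal sketch, not part of this commit (route and app names are invented):

import asyncio
import time

from sanic import Sanic
from sanic.response import json

app = Sanic("async_test_sketch")


@app.get("/sync_offloaded")
async def sync_offloaded_handler(request):
    # The blocking sleep runs in the default executor, so the event loop stays responsive.
    await asyncio.get_running_loop().run_in_executor(None, time.sleep, 2)
    return json({'code': 1})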
async_test/server2.py  0 → 100644
import grpc
import cv2
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import prediction_service_pb2_grpc, predict_pb2

from sanic import Sanic
from sanic.response import json

from classification import classifier

app = Sanic("async_test")

# TODO: read these settings from a config file
tf_serving_settings = {
    'servers': {
        'server_1': {
            'host': '192.168.10.191',
            'port': '8500',
            'options': [
                ('grpc.max_send_message_length', 1000 * 1024 * 1024),
                ('grpc.max_receive_message_length', 1000 * 1024 * 1024),
            ],
        }
    },
}
app.config.update(tf_serving_settings)


@app.post("/sync_classification")
async def sync_handler(request):
    image = request.files.get("image")
    img_array = np.frombuffer(image.body, np.uint8)
    image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    input_images = classifier.preprocess_input(image)

    # print(type(image))

    # See prediction_service.proto for gRPC request/response details.
    request = predict_pb2.PredictRequest()
    request.model_spec.name = classifier.model_name
    request.model_spec.signature_name = classifier.signature_name
    stub = getattr(app.ctx, classifier.server_name)

    request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(input_images))
    result = stub.Predict(request, 100.0)  # 100 secs timeout
    outputs = tf.make_ndarray(result.outputs['output'])

    res = classifier.reprocess_output(outputs)
    return json(res)


# @app.get("/async")
# async def async_handler(request):
#     await asyncio.sleep(2)
#     return json({'code': 1})


@app.listener("before_server_start")
async def set_grpc_channel(app, loop):
    for server_name, server_settings in app.config['servers'].items():
        channel = grpc.insecure_channel(
            '{0}:{1}'.format(server_settings['host'], server_settings['port']),
            options=server_settings.get('options'))
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        setattr(app.ctx, server_name, stub)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=6699, workers=5)
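Worth noting: stub.Predict(request, 100.0) is the blocking gRPC call, so this handler ties up the worker's event loop while TF Serving responds, much like the /sync case in server1.py. A rough sketch of a non-blocking variant built on grpc's asyncio API (illustrative only, not part of this commit; the helper name is invented):

import grpc
import tensorflow as tf
from tensorflow_serving.apis import prediction_service_pb2_grpc, predict_pb2


async def predict_async(host, port, options, model_name, signature_name, input_images):
    # grpc.aio stubs return awaitable calls, so the event loop stays free while TF Serving works.
    async with grpc.aio.insecure_channel('{0}:{1}'.format(host, port), options=options) as channel:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        req = predict_pb2.PredictRequest()
        req.model_spec.name = model_name
        req.model_spec.signature_name = signature_name
        req.inputs['input_1'].CopyFrom(tf.make_tensor_proto(input_images))
        result = await stub.Predict(req, timeout=100.0)
        return tf.make_ndarray(result.outputs['output'])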
@@ -9,5 +9,5 @@ classifier = F3Classification(
 )
 
 classifier.load_model(load_weights_path=os.path.join(
-    os.path.dirname(os.path.abspath(__file__)), 'ckpt_prod.h5'))
+    os.path.dirname(os.path.abspath(__file__)), 'ckpt_best_0.5.h5'))
 
 import os
 import random
+import cv2
 import tensorflow as tf
 import tensorflow_addons as tfa
 
@@ -23,10 +24,9 @@ class F3Classification(BaseModel):
         self.model = None
 
     @staticmethod
-    def gpu_config():
+    def gpu_config(gpu_idx=0):
         gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
-        # print(gpus)
-        tf.config.set_visible_devices(devices=gpus[1], device_type='GPU')
+        tf.config.set_visible_devices(devices=gpus[gpu_idx], device_type='GPU')
 
     @staticmethod
     def get_class_label_map(class_name_list, class_other_first=False):
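A note on gpu_config: gpus[gpu_idx] raises an IndexError when the machine has fewer GPUs than gpu_idx + 1 (the train and evaluate calls below pass 1 and 3). A slightly defensive variant, purely illustrative and not part of this commit:

import tensorflow as tf


def gpu_config(gpu_idx=0):
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    if not gpus:
        return  # CPU-only machine: nothing to pin
    gpu = gpus[min(gpu_idx, len(gpus) - 1)]  # clamp to the GPUs that actually exist
    tf.config.set_visible_devices(devices=gpu, device_type='GPU')
    tf.config.experimental.set_memory_growth(gpu, True)  # optional: don't grab all memory up front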
@@ -53,15 +53,20 @@ class F3Classification(BaseModel):
     @staticmethod
     # @tf.function
     def random_rgb_2_bgr(image, label):
-        # 1/5
-        if random.random() < 0.1:
+        # 1/2
+        if random.random() < 0.5:
+            image = image[:, :, ::-1]
+        return image, label
+
+    @staticmethod
+    # @tf.function
+    def rgb_2_bgr(image, label):
         image = image[:, :, ::-1]
         return image, label
 
     @staticmethod
     # @tf.function
     def random_grayscale_expand(image, label):
-        # 1/10
         if random.random() < 0.1:
             image = tf.image.rgb_to_grayscale(image)
             image = tf.image.grayscale_to_rgb(image)
@@ -69,22 +74,19 @@ class F3Classification(BaseModel):
 
     @staticmethod
     def random_flip_left_right(image, label):
-        # 1/10
-        if random.random() < 0.2:
-            image = tf.image.random_flip_left_right(image)
+        # if random.random() < 0.2:
+        image = tf.image.random_flip_left_right(image)
         return image, label
 
     @staticmethod
     def random_flip_up_down(image, label):
-        # 1/10
-        if random.random() < 0.2:
-            image = tf.image.random_flip_up_down(image)
+        # if random.random() < 0.2:
+        image = tf.image.random_flip_up_down(image)
         return image, label
 
     @staticmethod
     def random_rot90(image, label):
-        # 1/10
-        if random.random() < 0.1:
+        if random.random() < 0.3:
             image = tf.image.rot90(image, k=random.randint(1, 3))
         return image, label
 
@@ -93,13 +95,15 @@ class F3Classification(BaseModel):
     def load_image(image_path, label):
         image = tf.io.read_file(image_path)
         # image = tf.image.decode_image(image, channels=3)  # TODO ?
-        image = tf.image.decode_png(image, channels=3)
+        # image = tf.image.decode_png(image, channels=3)
+        image = tf.image.decode_jpeg(image, channels=3, dct_method='INTEGER_ACCURATE')
+        image = tf.image.resize(image, [224, 224])
         return image, label
 
     @staticmethod
     # @tf.function
     def preprocess_input(image, label):
-        image = tf.image.resize(image, [224, 224])
+        # image = tf.image.resize(image, [224, 224])
         image = applications.mobilenet_v2.preprocess_input(image)
         return image, label
 
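The switch to decode_jpeg with dct_method='INTEGER_ACCURATE' presumably brings training-side decoding closer to the OpenCV decoding used at serving time in server2.py, since that dct_method makes TF's JPEG decoder behave much more like libjpeg as used by OpenCV and PIL than the default fast mode. A quick, optional way to check the remaining gap on a sample image (the path is a placeholder; this snippet is not part of the commit):

import cv2
import numpy as np
import tensorflow as tf

path = 'sample.jpg'  # placeholder
tf_img = tf.image.decode_jpeg(tf.io.read_file(path), channels=3,
                              dct_method='INTEGER_ACCURATE').numpy()
cv_img = cv2.imread(path, cv2.IMREAD_COLOR)[:, :, ::-1]  # OpenCV gives BGR; flip to RGB
print(np.abs(tf_img.astype(np.int16) - cv_img.astype(np.int16)).max())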
@@ -131,7 +135,9 @@ class F3Classification(BaseModel):
 
         base_model = MobileNetV2(
             input_shape=(224, 224, 3),
-            alpha=0.35,
+            # alpha=0.35,
+            alpha=0.5,
+            # alpha=1,
             include_top=False,
             weights='imagenet',
             pooling='avg',
@@ -143,18 +149,15 @@ class F3Classification(BaseModel):
         x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
         self.model = models.Model(inputs=base_model.input, outputs=x)
 
-        if for_training:
+        if isinstance(load_weights_path, str) and os.path.isfile(load_weights_path):
+            self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True)
+        elif for_training:
             freeze = True
             for layer in self.model.layers:
                 layer.trainable = not freeze
                 if freeze and layer.name == 'block_16_project_BN':
                     freeze = False
 
-        if isinstance(load_weights_path, str):
-            if not os.path.isfile(load_weights_path):
-                raise Exception('load_weights_path can not find')
-            self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True)
-
     def train(self,
               dataset_dir,
               epoch,
@@ -167,13 +170,14 @@ class F3Classification(BaseModel):
               thresholds=0.5,
               metrics_name='accuracy'):
 
-        self.gpu_config()
+        self.gpu_config(1)
 
         self.load_model(for_training=True, load_weights_path=load_weights_path)
         self.model.summary()
 
         self.model.compile(
-            optimizer=optimizers.Adam(learning_rate=3e-4),
+            # optimizer=optimizers.Adam(learning_rate=3e-4),
+            optimizer=optimizers.Adam(learning_rate=1e-4),
             loss=tfa.losses.SigmoidFocalCrossEntropy(),  # TODO ?
             metrics=[CustomMetric(thresholds, name=metrics_name), ],
 
@@ -193,7 +197,8 @@ class F3Classification(BaseModel):
                 'random_flip_left_right',
                 'random_flip_up_down',
                 'random_rot90',
-                'random_rgb_2_bgr',
+                # 'random_rgb_2_bgr',
+                # 'rgb_2_bgr',
                 'random_grayscale_expand'
             ],
         )
@@ -201,16 +206,21 @@ class F3Classification(BaseModel):
             dataset_dir=os.path.join(dataset_dir, validate_dir_name),
             name=validate_dir_name,
             batch_size=batch_size,
-            augmentation_methods=[]
+            augmentation_methods=[
+                'rgb_2_bgr'
+            ]
         )
 
         ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
+        es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
 
         history = self.model.fit(
             train_dataset,
             epochs=epoch,
             validation_data=validate_dataset,
-            callbacks=[ckpt_callback, ],
+            callbacks=[ckpt_callback,
+                       # es_callback
+                       ],
         )
 
         history_save(history, history_save_path, metrics_name)
@@ -222,7 +232,7 @@ class F3Classification(BaseModel):
                  batch_size,
                  validate_dir_name='test',
                  thresholds=0.5):
-        self.gpu_config()
+        self.gpu_config(3)
 
         self.load_model(load_weights_path=load_weights_path)
         self.model.summary()
@@ -231,7 +241,9 @@ class F3Classification(BaseModel):
             dataset_dir=os.path.join(dataset_dir, validate_dir_name),
             name=validate_dir_name,
             batch_size=batch_size,
-            augmentation_methods=[]
+            augmentation_methods=[
+                'rgb_2_bgr'
+            ]
         )
 
         label_true_list = []