add async test
Showing
9 changed files
with
207 additions
and
31 deletions
async_test/classification/__init__.py
0 → 100644
async_test/classification/const.py
0 → 100644
async_test/classification/model.py
0 → 100644
| 1 | import tensorflow as tf | ||
| 2 | |||
| 3 | |||
| 4 | class F3Classification: | ||
| 5 | |||
| 6 | def __init__(self, class_name_list, class_other_first, *args, **kwargs): | ||
| 7 | self.class_name_list = class_name_list | ||
| 8 | self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1 | ||
| 9 | self.model_name = 'classification_model' | ||
| 10 | self.signature_name = 'serving_default' | ||
| 11 | self.server_name = 'server_1' | ||
| 12 | |||
| 13 | @staticmethod | ||
| 14 | def preprocess_input(image): | ||
| 15 | image = tf.image.resize(image, [224, 224]) | ||
| 16 | image = tf.keras.applications.mobilenet_v2.preprocess_input(image) | ||
| 17 | input_images = tf.expand_dims(image, axis=0) | ||
| 18 | return input_images | ||
| 19 | |||
| 20 | def reprocess_output(self, outputs, thresholds=0.5): | ||
| 21 | for output in outputs: | ||
| 22 | idx = tf.math.argmax(output) | ||
| 23 | confidence = output[idx] | ||
| 24 | if confidence < thresholds: | ||
| 25 | idx = -1 | ||
| 26 | label = self.class_name_list[idx + 1] | ||
| 27 | break | ||
| 28 | |||
| 29 | res = { | ||
| 30 | 'label': label, | ||
| 31 | 'confidence': confidence | ||
| 32 | } | ||
| 33 | return res | ||
| 34 |
async_test/locustfile.py
0 → 100644
| 1 | import time | ||
| 2 | from locust import HttpUser, task, between, constant, tag | ||
| 3 | |||
| 4 | |||
| 5 | class QuickstartUser(HttpUser): | ||
| 6 | # wait_time = between(1, 5) | ||
| 7 | |||
| 8 | @tag('sync') | ||
| 9 | @task | ||
| 10 | def sync_test(self): | ||
| 11 | self.client.get("/sync") | ||
| 12 | |||
| 13 | @tag('async') | ||
| 14 | @task | ||
| 15 | def async_test(self): | ||
| 16 | self.client.get("/async") | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
async_test/requirements.txt
0 → 100644
async_test/server1.py
0 → 100644
| 1 | import asyncio | ||
| 2 | import time | ||
| 3 | |||
| 4 | from sanic import Sanic | ||
| 5 | from sanic.response import json | ||
| 6 | |||
| 7 | app = Sanic("async_test") | ||
| 8 | |||
| 9 | |||
| 10 | @app.get("/sync") | ||
| 11 | async def sync_handler(request): | ||
| 12 | time.sleep(2) | ||
| 13 | return json({'code': 1}) | ||
| 14 | |||
| 15 | |||
| 16 | @app.get("/async") | ||
| 17 | async def async_handler(request): | ||
| 18 | await asyncio.sleep(2) | ||
| 19 | return json({'code': 1}) | ||
| 20 | |||
| 21 | |||
| 22 | if __name__ == '__main__': | ||
| 23 | app.run(host='0.0.0.0', port=1337, workers=5) |
async_test/server2.py
0 → 100644
| 1 | import grpc | ||
| 2 | import cv2 | ||
| 3 | import numpy as np | ||
| 4 | import tensorflow as tf | ||
| 5 | from tensorflow_serving.apis import prediction_service_pb2_grpc, predict_pb2 | ||
| 6 | |||
| 7 | from sanic import Sanic | ||
| 8 | from sanic.response import json | ||
| 9 | |||
| 10 | from classification import classifier | ||
| 11 | |||
| 12 | app = Sanic("async_test") | ||
| 13 | |||
| 14 | # TODO 从配置文件读取 | ||
| 15 | tf_serving_settings = { | ||
| 16 | 'servers': { | ||
| 17 | 'server_1': { | ||
| 18 | 'host': '192.168.10.191', | ||
| 19 | 'port': '8500', | ||
| 20 | 'options': [ | ||
| 21 | ('grpc.max_send_message_length', 1000 * 1024 * 1024), | ||
| 22 | ('grpc.max_receive_message_length', 1000 * 1024 * 1024), | ||
| 23 | ], | ||
| 24 | } | ||
| 25 | }, | ||
| 26 | } | ||
| 27 | app.config.update(tf_serving_settings) | ||
| 28 | |||
| 29 | |||
| 30 | @app.post("/sync_classification") | ||
| 31 | async def sync_handler(request): | ||
| 32 | image = request.files.get("image") | ||
| 33 | img_array = np.frombuffer(image.body, np.uint8) | ||
| 34 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
| 35 | input_images = classifier.preprocess_input(image) | ||
| 36 | |||
| 37 | # print(type(image)) | ||
| 38 | |||
| 39 | # See prediction_service.proto for gRPC request/response details. | ||
| 40 | request = predict_pb2.PredictRequest() | ||
| 41 | request.model_spec.name = classifier.model_name | ||
| 42 | request.model_spec.signature_name = classifier.signature_name | ||
| 43 | stub = getattr(app.ctx, classifier.server_name) | ||
| 44 | |||
| 45 | request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(input_images)) | ||
| 46 | result = stub.Predict(request, 100.0) # 100 secs timeout | ||
| 47 | outputs = tf.make_ndarray(result.outputs['output']) | ||
| 48 | |||
| 49 | res = classifier.reprocess_output(outputs) | ||
| 50 | return json(res) | ||
| 51 | |||
| 52 | |||
| 53 | # @app.get("/async") | ||
| 54 | # async def async_handler(request): | ||
| 55 | # await asyncio.sleep(2) | ||
| 56 | # return json({'code': 1}) | ||
| 57 | |||
| 58 | |||
| 59 | @app.listener("before_server_start") | ||
| 60 | async def set_grpc_channel(app, loop): | ||
| 61 | for server_name, server_settings in app.config['servers'].items(): | ||
| 62 | channel = grpc.insecure_channel( | ||
| 63 | '{0}:{1}'.format(server_settings['host'], server_settings['port']), | ||
| 64 | options=server_settings.get('options')) | ||
| 65 | stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) | ||
| 66 | setattr(app.ctx, server_name, stub) | ||
| 67 | |||
| 68 | |||
| 69 | if __name__ == '__main__': | ||
| 70 | app.run(host='0.0.0.0', port=6699, workers=5) |
| ... | @@ -9,5 +9,5 @@ classifier = F3Classification( | ... | @@ -9,5 +9,5 @@ classifier = F3Classification( |
| 9 | ) | 9 | ) |
| 10 | 10 | ||
| 11 | classifier.load_model(load_weights_path=os.path.join( | 11 | classifier.load_model(load_weights_path=os.path.join( |
| 12 | os.path.dirname(os.path.abspath(__file__)), 'ckpt_prod.h5')) | 12 | os.path.dirname(os.path.abspath(__file__)), 'ckpt_best_0.5.h5')) |
| 13 | 13 | ... | ... |
| 1 | import os | 1 | import os |
| 2 | import random | 2 | import random |
| 3 | import cv2 | ||
| 3 | import tensorflow as tf | 4 | import tensorflow as tf |
| 4 | import tensorflow_addons as tfa | 5 | import tensorflow_addons as tfa |
| 5 | 6 | ||
| ... | @@ -23,10 +24,9 @@ class F3Classification(BaseModel): | ... | @@ -23,10 +24,9 @@ class F3Classification(BaseModel): |
| 23 | self.model = None | 24 | self.model = None |
| 24 | 25 | ||
| 25 | @staticmethod | 26 | @staticmethod |
| 26 | def gpu_config(): | 27 | def gpu_config(gpu_idx=0): |
| 27 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU') | 28 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU') |
| 28 | # print(gpus) | 29 | tf.config.set_visible_devices(devices=gpus[gpu_idx], device_type='GPU') |
| 29 | tf.config.set_visible_devices(devices=gpus[1], device_type='GPU') | ||
| 30 | 30 | ||
| 31 | @staticmethod | 31 | @staticmethod |
| 32 | def get_class_label_map(class_name_list, class_other_first=False): | 32 | def get_class_label_map(class_name_list, class_other_first=False): |
| ... | @@ -53,15 +53,20 @@ class F3Classification(BaseModel): | ... | @@ -53,15 +53,20 @@ class F3Classification(BaseModel): |
| 53 | @staticmethod | 53 | @staticmethod |
| 54 | # @tf.function | 54 | # @tf.function |
| 55 | def random_rgb_2_bgr(image, label): | 55 | def random_rgb_2_bgr(image, label): |
| 56 | # 1/5 | 56 | # 1/2 |
| 57 | if random.random() < 0.1: | 57 | if random.random() < 0.5: |
| 58 | image = image[:, :, ::-1] | 58 | image = image[:, :, ::-1] |
| 59 | return image, label | 59 | return image, label |
| 60 | 60 | ||
| 61 | @staticmethod | 61 | @staticmethod |
| 62 | # @tf.function | 62 | # @tf.function |
| 63 | def rgb_2_bgr(image, label): | ||
| 64 | image = image[:, :, ::-1] | ||
| 65 | return image, label | ||
| 66 | |||
| 67 | @staticmethod | ||
| 68 | # @tf.function | ||
| 63 | def random_grayscale_expand(image, label): | 69 | def random_grayscale_expand(image, label): |
| 64 | # 1/10 | ||
| 65 | if random.random() < 0.1: | 70 | if random.random() < 0.1: |
| 66 | image = tf.image.rgb_to_grayscale(image) | 71 | image = tf.image.rgb_to_grayscale(image) |
| 67 | image = tf.image.grayscale_to_rgb(image) | 72 | image = tf.image.grayscale_to_rgb(image) |
| ... | @@ -69,22 +74,19 @@ class F3Classification(BaseModel): | ... | @@ -69,22 +74,19 @@ class F3Classification(BaseModel): |
| 69 | 74 | ||
| 70 | @staticmethod | 75 | @staticmethod |
| 71 | def random_flip_left_right(image, label): | 76 | def random_flip_left_right(image, label): |
| 72 | # 1/10 | 77 | # if random.random() < 0.2: |
| 73 | if random.random() < 0.2: | 78 | image = tf.image.random_flip_left_right(image) |
| 74 | image = tf.image.random_flip_left_right(image) | ||
| 75 | return image, label | 79 | return image, label |
| 76 | 80 | ||
| 77 | @staticmethod | 81 | @staticmethod |
| 78 | def random_flip_up_down(image, label): | 82 | def random_flip_up_down(image, label): |
| 79 | # 1/10 | 83 | # if random.random() < 0.2: |
| 80 | if random.random() < 0.2: | 84 | image = tf.image.random_flip_up_down(image) |
| 81 | image = tf.image.random_flip_up_down(image) | ||
| 82 | return image, label | 85 | return image, label |
| 83 | 86 | ||
| 84 | @staticmethod | 87 | @staticmethod |
| 85 | def random_rot90(image, label): | 88 | def random_rot90(image, label): |
| 86 | # 1/10 | 89 | if random.random() < 0.3: |
| 87 | if random.random() < 0.1: | ||
| 88 | image = tf.image.rot90(image, k=random.randint(1, 3)) | 90 | image = tf.image.rot90(image, k=random.randint(1, 3)) |
| 89 | return image, label | 91 | return image, label |
| 90 | 92 | ||
| ... | @@ -93,13 +95,15 @@ class F3Classification(BaseModel): | ... | @@ -93,13 +95,15 @@ class F3Classification(BaseModel): |
| 93 | def load_image(image_path, label): | 95 | def load_image(image_path, label): |
| 94 | image = tf.io.read_file(image_path) | 96 | image = tf.io.read_file(image_path) |
| 95 | # image = tf.image.decode_image(image, channels=3) # TODO ? | 97 | # image = tf.image.decode_image(image, channels=3) # TODO ? |
| 96 | image = tf.image.decode_png(image, channels=3) | 98 | # image = tf.image.decode_png(image, channels=3) |
| 99 | image = tf.image.decode_jpeg(image, channels=3, dct_method='INTEGER_ACCURATE') | ||
| 100 | image = tf.image.resize(image, [224, 224]) | ||
| 97 | return image, label | 101 | return image, label |
| 98 | 102 | ||
| 99 | @staticmethod | 103 | @staticmethod |
| 100 | # @tf.function | 104 | # @tf.function |
| 101 | def preprocess_input(image, label): | 105 | def preprocess_input(image, label): |
| 102 | image = tf.image.resize(image, [224, 224]) | 106 | # image = tf.image.resize(image, [224, 224]) |
| 103 | image = applications.mobilenet_v2.preprocess_input(image) | 107 | image = applications.mobilenet_v2.preprocess_input(image) |
| 104 | return image, label | 108 | return image, label |
| 105 | 109 | ||
| ... | @@ -131,7 +135,9 @@ class F3Classification(BaseModel): | ... | @@ -131,7 +135,9 @@ class F3Classification(BaseModel): |
| 131 | 135 | ||
| 132 | base_model = MobileNetV2( | 136 | base_model = MobileNetV2( |
| 133 | input_shape=(224, 224, 3), | 137 | input_shape=(224, 224, 3), |
| 134 | alpha=0.35, | 138 | # alpha=0.35, |
| 139 | alpha=0.5, | ||
| 140 | # alpha=1, | ||
| 135 | include_top=False, | 141 | include_top=False, |
| 136 | weights='imagenet', | 142 | weights='imagenet', |
| 137 | pooling='avg', | 143 | pooling='avg', |
| ... | @@ -143,18 +149,15 @@ class F3Classification(BaseModel): | ... | @@ -143,18 +149,15 @@ class F3Classification(BaseModel): |
| 143 | x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) | 149 | x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) |
| 144 | self.model = models.Model(inputs=base_model.input, outputs=x) | 150 | self.model = models.Model(inputs=base_model.input, outputs=x) |
| 145 | 151 | ||
| 146 | if for_training: | 152 | if isinstance(load_weights_path, str) and os.path.isfile(load_weights_path): |
| 153 | self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True) | ||
| 154 | elif for_training: | ||
| 147 | freeze = True | 155 | freeze = True |
| 148 | for layer in self.model.layers: | 156 | for layer in self.model.layers: |
| 149 | layer.trainable = not freeze | 157 | layer.trainable = not freeze |
| 150 | if freeze and layer.name == 'block_16_project_BN': | 158 | if freeze and layer.name == 'block_16_project_BN': |
| 151 | freeze = False | 159 | freeze = False |
| 152 | 160 | ||
| 153 | if isinstance(load_weights_path, str): | ||
| 154 | if not os.path.isfile(load_weights_path): | ||
| 155 | raise Exception('load_weights_path can not find') | ||
| 156 | self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True) | ||
| 157 | |||
| 158 | def train(self, | 161 | def train(self, |
| 159 | dataset_dir, | 162 | dataset_dir, |
| 160 | epoch, | 163 | epoch, |
| ... | @@ -167,13 +170,14 @@ class F3Classification(BaseModel): | ... | @@ -167,13 +170,14 @@ class F3Classification(BaseModel): |
| 167 | thresholds=0.5, | 170 | thresholds=0.5, |
| 168 | metrics_name='accuracy'): | 171 | metrics_name='accuracy'): |
| 169 | 172 | ||
| 170 | self.gpu_config() | 173 | self.gpu_config(1) |
| 171 | 174 | ||
| 172 | self.load_model(for_training=True, load_weights_path=load_weights_path) | 175 | self.load_model(for_training=True, load_weights_path=load_weights_path) |
| 173 | self.model.summary() | 176 | self.model.summary() |
| 174 | 177 | ||
| 175 | self.model.compile( | 178 | self.model.compile( |
| 176 | optimizer=optimizers.Adam(learning_rate=3e-4), | 179 | # optimizer=optimizers.Adam(learning_rate=3e-4), |
| 180 | optimizer=optimizers.Adam(learning_rate=1e-4), | ||
| 177 | loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO ? | 181 | loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO ? |
| 178 | metrics=[CustomMetric(thresholds, name=metrics_name), ], | 182 | metrics=[CustomMetric(thresholds, name=metrics_name), ], |
| 179 | 183 | ||
| ... | @@ -193,7 +197,8 @@ class F3Classification(BaseModel): | ... | @@ -193,7 +197,8 @@ class F3Classification(BaseModel): |
| 193 | 'random_flip_left_right', | 197 | 'random_flip_left_right', |
| 194 | 'random_flip_up_down', | 198 | 'random_flip_up_down', |
| 195 | 'random_rot90', | 199 | 'random_rot90', |
| 196 | 'random_rgb_2_bgr', | 200 | # 'random_rgb_2_bgr', |
| 201 | # 'rgb_2_bgr', | ||
| 197 | 'random_grayscale_expand' | 202 | 'random_grayscale_expand' |
| 198 | ], | 203 | ], |
| 199 | ) | 204 | ) |
| ... | @@ -201,16 +206,21 @@ class F3Classification(BaseModel): | ... | @@ -201,16 +206,21 @@ class F3Classification(BaseModel): |
| 201 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), | 206 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), |
| 202 | name=validate_dir_name, | 207 | name=validate_dir_name, |
| 203 | batch_size=batch_size, | 208 | batch_size=batch_size, |
| 204 | augmentation_methods=[] | 209 | augmentation_methods=[ |
| 210 | 'rgb_2_bgr' | ||
| 211 | ] | ||
| 205 | ) | 212 | ) |
| 206 | 213 | ||
| 207 | ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) | 214 | ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) |
| 215 | es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True) | ||
| 208 | 216 | ||
| 209 | history = self.model.fit( | 217 | history = self.model.fit( |
| 210 | train_dataset, | 218 | train_dataset, |
| 211 | epochs=epoch, | 219 | epochs=epoch, |
| 212 | validation_data=validate_dataset, | 220 | validation_data=validate_dataset, |
| 213 | callbacks=[ckpt_callback, ], | 221 | callbacks=[ckpt_callback, |
| 222 | # es_callback | ||
| 223 | ], | ||
| 214 | ) | 224 | ) |
| 215 | 225 | ||
| 216 | history_save(history, history_save_path, metrics_name) | 226 | history_save(history, history_save_path, metrics_name) |
| ... | @@ -222,7 +232,7 @@ class F3Classification(BaseModel): | ... | @@ -222,7 +232,7 @@ class F3Classification(BaseModel): |
| 222 | batch_size, | 232 | batch_size, |
| 223 | validate_dir_name='test', | 233 | validate_dir_name='test', |
| 224 | thresholds=0.5): | 234 | thresholds=0.5): |
| 225 | self.gpu_config() | 235 | self.gpu_config(3) |
| 226 | 236 | ||
| 227 | self.load_model(load_weights_path=load_weights_path) | 237 | self.load_model(load_weights_path=load_weights_path) |
| 228 | self.model.summary() | 238 | self.model.summary() |
| ... | @@ -231,7 +241,9 @@ class F3Classification(BaseModel): | ... | @@ -231,7 +241,9 @@ class F3Classification(BaseModel): |
| 231 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), | 241 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), |
| 232 | name=validate_dir_name, | 242 | name=validate_dir_name, |
| 233 | batch_size=batch_size, | 243 | batch_size=batch_size, |
| 234 | augmentation_methods=[] | 244 | augmentation_methods=[ |
| 245 | 'rgb_2_bgr' | ||
| 246 | ] | ||
| 235 | ) | 247 | ) |
| 236 | 248 | ||
| 237 | label_true_list = [] | 249 | label_true_list = [] | ... | ... |
-
Please register or sign in to post a comment