a33165ce by 周伟奇

add async test

1 parent 158a5293
1 from .model import F3Classification
2 from .const import CLASS_CN_LIST, CLASS_OTHER_FIRST
3
4 classifier = F3Classification(
5 class_name_list=CLASS_CN_LIST,
6 class_other_first=CLASS_OTHER_FIRST
7 )
8
9
1 CLASS_OTHER_CN = '其他'
2
3 CLASS_OTHER_FIRST = True
4
5 # CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
6 CLASS_CN_LIST = [CLASS_OTHER_CN, '营业执照', '经销商授权书', '个人授权书']
7
8 OTHER_THRESHOLDS = 0.5
9
1 import tensorflow as tf
2
3
4 class F3Classification:
5
6 def __init__(self, class_name_list, class_other_first, *args, **kwargs):
7 self.class_name_list = class_name_list
8 self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
9 self.model_name = 'classification_model'
10 self.signature_name = 'serving_default'
11 self.server_name = 'server_1'
12
13 @staticmethod
14 def preprocess_input(image):
15 image = tf.image.resize(image, [224, 224])
16 image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
17 input_images = tf.expand_dims(image, axis=0)
18 return input_images
19
20 def reprocess_output(self, outputs, thresholds=0.5):
21 for output in outputs:
22 idx = tf.math.argmax(output)
23 confidence = output[idx]
24 if confidence < thresholds:
25 idx = -1
26 label = self.class_name_list[idx + 1]
27 break
28
29 res = {
30 'label': label,
31 'confidence': confidence
32 }
33 return res
34
1 import time
2 from locust import HttpUser, task, between, constant, tag
3
4
5 class QuickstartUser(HttpUser):
6 # wait_time = between(1, 5)
7
8 @tag('sync')
9 @task
10 def sync_test(self):
11 self.client.get("/sync")
12
13 @tag('async')
14 @task
15 def async_test(self):
16 self.client.get("/async")
...\ No newline at end of file ...\ No newline at end of file
1 sanic==22.6.0
2 locust==2.10.1
3
1 import asyncio
2 import time
3
4 from sanic import Sanic
5 from sanic.response import json
6
7 app = Sanic("async_test")
8
9
10 @app.get("/sync")
11 async def sync_handler(request):
12 time.sleep(2)
13 return json({'code': 1})
14
15
16 @app.get("/async")
17 async def async_handler(request):
18 await asyncio.sleep(2)
19 return json({'code': 1})
20
21
22 if __name__ == '__main__':
23 app.run(host='0.0.0.0', port=1337, workers=5)
1 import grpc
2 import cv2
3 import numpy as np
4 import tensorflow as tf
5 from tensorflow_serving.apis import prediction_service_pb2_grpc, predict_pb2
6
7 from sanic import Sanic
8 from sanic.response import json
9
10 from classification import classifier
11
12 app = Sanic("async_test")
13
14 # TODO 从配置文件读取
15 tf_serving_settings = {
16 'servers': {
17 'server_1': {
18 'host': '192.168.10.191',
19 'port': '8500',
20 'options': [
21 ('grpc.max_send_message_length', 1000 * 1024 * 1024),
22 ('grpc.max_receive_message_length', 1000 * 1024 * 1024),
23 ],
24 }
25 },
26 }
27 app.config.update(tf_serving_settings)
28
29
30 @app.post("/sync_classification")
31 async def sync_handler(request):
32 image = request.files.get("image")
33 img_array = np.frombuffer(image.body, np.uint8)
34 image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
35 input_images = classifier.preprocess_input(image)
36
37 # print(type(image))
38
39 # See prediction_service.proto for gRPC request/response details.
40 request = predict_pb2.PredictRequest()
41 request.model_spec.name = classifier.model_name
42 request.model_spec.signature_name = classifier.signature_name
43 stub = getattr(app.ctx, classifier.server_name)
44
45 request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(input_images))
46 result = stub.Predict(request, 100.0) # 100 secs timeout
47 outputs = tf.make_ndarray(result.outputs['output'])
48
49 res = classifier.reprocess_output(outputs)
50 return json(res)
51
52
53 # @app.get("/async")
54 # async def async_handler(request):
55 # await asyncio.sleep(2)
56 # return json({'code': 1})
57
58
59 @app.listener("before_server_start")
60 async def set_grpc_channel(app, loop):
61 for server_name, server_settings in app.config['servers'].items():
62 channel = grpc.insecure_channel(
63 '{0}:{1}'.format(server_settings['host'], server_settings['port']),
64 options=server_settings.get('options'))
65 stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
66 setattr(app.ctx, server_name, stub)
67
68
69 if __name__ == '__main__':
70 app.run(host='0.0.0.0', port=6699, workers=5)
...@@ -9,5 +9,5 @@ classifier = F3Classification( ...@@ -9,5 +9,5 @@ classifier = F3Classification(
9 ) 9 )
10 10
11 classifier.load_model(load_weights_path=os.path.join( 11 classifier.load_model(load_weights_path=os.path.join(
12 os.path.dirname(os.path.abspath(__file__)), 'ckpt_prod.h5')) 12 os.path.dirname(os.path.abspath(__file__)), 'ckpt_best_0.5.h5'))
13 13
......
1 import os 1 import os
2 import random 2 import random
3 import cv2
3 import tensorflow as tf 4 import tensorflow as tf
4 import tensorflow_addons as tfa 5 import tensorflow_addons as tfa
5 6
...@@ -23,10 +24,9 @@ class F3Classification(BaseModel): ...@@ -23,10 +24,9 @@ class F3Classification(BaseModel):
23 self.model = None 24 self.model = None
24 25
25 @staticmethod 26 @staticmethod
26 def gpu_config(): 27 def gpu_config(gpu_idx=0):
27 gpus = tf.config.experimental.list_physical_devices(device_type='GPU') 28 gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
28 # print(gpus) 29 tf.config.set_visible_devices(devices=gpus[gpu_idx], device_type='GPU')
29 tf.config.set_visible_devices(devices=gpus[1], device_type='GPU')
30 30
31 @staticmethod 31 @staticmethod
32 def get_class_label_map(class_name_list, class_other_first=False): 32 def get_class_label_map(class_name_list, class_other_first=False):
...@@ -53,15 +53,20 @@ class F3Classification(BaseModel): ...@@ -53,15 +53,20 @@ class F3Classification(BaseModel):
53 @staticmethod 53 @staticmethod
54 # @tf.function 54 # @tf.function
55 def random_rgb_2_bgr(image, label): 55 def random_rgb_2_bgr(image, label):
56 # 1/5 56 # 1/2
57 if random.random() < 0.1: 57 if random.random() < 0.5:
58 image = image[:, :, ::-1] 58 image = image[:, :, ::-1]
59 return image, label 59 return image, label
60 60
61 @staticmethod 61 @staticmethod
62 # @tf.function 62 # @tf.function
63 def rgb_2_bgr(image, label):
64 image = image[:, :, ::-1]
65 return image, label
66
67 @staticmethod
68 # @tf.function
63 def random_grayscale_expand(image, label): 69 def random_grayscale_expand(image, label):
64 # 1/10
65 if random.random() < 0.1: 70 if random.random() < 0.1:
66 image = tf.image.rgb_to_grayscale(image) 71 image = tf.image.rgb_to_grayscale(image)
67 image = tf.image.grayscale_to_rgb(image) 72 image = tf.image.grayscale_to_rgb(image)
...@@ -69,22 +74,19 @@ class F3Classification(BaseModel): ...@@ -69,22 +74,19 @@ class F3Classification(BaseModel):
69 74
70 @staticmethod 75 @staticmethod
71 def random_flip_left_right(image, label): 76 def random_flip_left_right(image, label):
72 # 1/10 77 # if random.random() < 0.2:
73 if random.random() < 0.2: 78 image = tf.image.random_flip_left_right(image)
74 image = tf.image.random_flip_left_right(image)
75 return image, label 79 return image, label
76 80
77 @staticmethod 81 @staticmethod
78 def random_flip_up_down(image, label): 82 def random_flip_up_down(image, label):
79 # 1/10 83 # if random.random() < 0.2:
80 if random.random() < 0.2: 84 image = tf.image.random_flip_up_down(image)
81 image = tf.image.random_flip_up_down(image)
82 return image, label 85 return image, label
83 86
84 @staticmethod 87 @staticmethod
85 def random_rot90(image, label): 88 def random_rot90(image, label):
86 # 1/10 89 if random.random() < 0.3:
87 if random.random() < 0.1:
88 image = tf.image.rot90(image, k=random.randint(1, 3)) 90 image = tf.image.rot90(image, k=random.randint(1, 3))
89 return image, label 91 return image, label
90 92
...@@ -93,13 +95,15 @@ class F3Classification(BaseModel): ...@@ -93,13 +95,15 @@ class F3Classification(BaseModel):
93 def load_image(image_path, label): 95 def load_image(image_path, label):
94 image = tf.io.read_file(image_path) 96 image = tf.io.read_file(image_path)
95 # image = tf.image.decode_image(image, channels=3) # TODO ? 97 # image = tf.image.decode_image(image, channels=3) # TODO ?
96 image = tf.image.decode_png(image, channels=3) 98 # image = tf.image.decode_png(image, channels=3)
99 image = tf.image.decode_jpeg(image, channels=3, dct_method='INTEGER_ACCURATE')
100 image = tf.image.resize(image, [224, 224])
97 return image, label 101 return image, label
98 102
99 @staticmethod 103 @staticmethod
100 # @tf.function 104 # @tf.function
101 def preprocess_input(image, label): 105 def preprocess_input(image, label):
102 image = tf.image.resize(image, [224, 224]) 106 # image = tf.image.resize(image, [224, 224])
103 image = applications.mobilenet_v2.preprocess_input(image) 107 image = applications.mobilenet_v2.preprocess_input(image)
104 return image, label 108 return image, label
105 109
...@@ -131,7 +135,9 @@ class F3Classification(BaseModel): ...@@ -131,7 +135,9 @@ class F3Classification(BaseModel):
131 135
132 base_model = MobileNetV2( 136 base_model = MobileNetV2(
133 input_shape=(224, 224, 3), 137 input_shape=(224, 224, 3),
134 alpha=0.35, 138 # alpha=0.35,
139 alpha=0.5,
140 # alpha=1,
135 include_top=False, 141 include_top=False,
136 weights='imagenet', 142 weights='imagenet',
137 pooling='avg', 143 pooling='avg',
...@@ -143,18 +149,15 @@ class F3Classification(BaseModel): ...@@ -143,18 +149,15 @@ class F3Classification(BaseModel):
143 x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) 149 x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
144 self.model = models.Model(inputs=base_model.input, outputs=x) 150 self.model = models.Model(inputs=base_model.input, outputs=x)
145 151
146 if for_training: 152 if isinstance(load_weights_path, str) and os.path.isfile(load_weights_path):
153 self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True)
154 elif for_training:
147 freeze = True 155 freeze = True
148 for layer in self.model.layers: 156 for layer in self.model.layers:
149 layer.trainable = not freeze 157 layer.trainable = not freeze
150 if freeze and layer.name == 'block_16_project_BN': 158 if freeze and layer.name == 'block_16_project_BN':
151 freeze = False 159 freeze = False
152 160
153 if isinstance(load_weights_path, str):
154 if not os.path.isfile(load_weights_path):
155 raise Exception('load_weights_path can not find')
156 self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True)
157
158 def train(self, 161 def train(self,
159 dataset_dir, 162 dataset_dir,
160 epoch, 163 epoch,
...@@ -167,13 +170,14 @@ class F3Classification(BaseModel): ...@@ -167,13 +170,14 @@ class F3Classification(BaseModel):
167 thresholds=0.5, 170 thresholds=0.5,
168 metrics_name='accuracy'): 171 metrics_name='accuracy'):
169 172
170 self.gpu_config() 173 self.gpu_config(1)
171 174
172 self.load_model(for_training=True, load_weights_path=load_weights_path) 175 self.load_model(for_training=True, load_weights_path=load_weights_path)
173 self.model.summary() 176 self.model.summary()
174 177
175 self.model.compile( 178 self.model.compile(
176 optimizer=optimizers.Adam(learning_rate=3e-4), 179 # optimizer=optimizers.Adam(learning_rate=3e-4),
180 optimizer=optimizers.Adam(learning_rate=1e-4),
177 loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO ? 181 loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO ?
178 metrics=[CustomMetric(thresholds, name=metrics_name), ], 182 metrics=[CustomMetric(thresholds, name=metrics_name), ],
179 183
...@@ -193,7 +197,8 @@ class F3Classification(BaseModel): ...@@ -193,7 +197,8 @@ class F3Classification(BaseModel):
193 'random_flip_left_right', 197 'random_flip_left_right',
194 'random_flip_up_down', 198 'random_flip_up_down',
195 'random_rot90', 199 'random_rot90',
196 'random_rgb_2_bgr', 200 # 'random_rgb_2_bgr',
201 # 'rgb_2_bgr',
197 'random_grayscale_expand' 202 'random_grayscale_expand'
198 ], 203 ],
199 ) 204 )
...@@ -201,16 +206,21 @@ class F3Classification(BaseModel): ...@@ -201,16 +206,21 @@ class F3Classification(BaseModel):
201 dataset_dir=os.path.join(dataset_dir, validate_dir_name), 206 dataset_dir=os.path.join(dataset_dir, validate_dir_name),
202 name=validate_dir_name, 207 name=validate_dir_name,
203 batch_size=batch_size, 208 batch_size=batch_size,
204 augmentation_methods=[] 209 augmentation_methods=[
210 'rgb_2_bgr'
211 ]
205 ) 212 )
206 213
207 ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) 214 ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
215 es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
208 216
209 history = self.model.fit( 217 history = self.model.fit(
210 train_dataset, 218 train_dataset,
211 epochs=epoch, 219 epochs=epoch,
212 validation_data=validate_dataset, 220 validation_data=validate_dataset,
213 callbacks=[ckpt_callback, ], 221 callbacks=[ckpt_callback,
222 # es_callback
223 ],
214 ) 224 )
215 225
216 history_save(history, history_save_path, metrics_name) 226 history_save(history, history_save_path, metrics_name)
...@@ -222,7 +232,7 @@ class F3Classification(BaseModel): ...@@ -222,7 +232,7 @@ class F3Classification(BaseModel):
222 batch_size, 232 batch_size,
223 validate_dir_name='test', 233 validate_dir_name='test',
224 thresholds=0.5): 234 thresholds=0.5):
225 self.gpu_config() 235 self.gpu_config(3)
226 236
227 self.load_model(load_weights_path=load_weights_path) 237 self.load_model(load_weights_path=load_weights_path)
228 self.model.summary() 238 self.model.summary()
...@@ -231,7 +241,9 @@ class F3Classification(BaseModel): ...@@ -231,7 +241,9 @@ class F3Classification(BaseModel):
231 dataset_dir=os.path.join(dataset_dir, validate_dir_name), 241 dataset_dir=os.path.join(dataset_dir, validate_dir_name),
232 name=validate_dir_name, 242 name=validate_dir_name,
233 batch_size=batch_size, 243 batch_size=batch_size,
234 augmentation_methods=[] 244 augmentation_methods=[
245 'rgb_2_bgr'
246 ]
235 ) 247 )
236 248
237 label_true_list = [] 249 label_true_list = []
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!