974e6763 by 乔峰昇

the first submit

1 parent 32d03d1b
# Build script for the speak-recognition demo executable (main.cpp linked
# against the prebuilt speakrecognize library, MNN and OpenCV).
cmake_minimum_required(VERSION 3.10)
project(main)

set(CMAKE_CXX_STANDARD 11)
# Fail at configure time instead of silently falling back to an older standard.
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(OpenCV REQUIRED)

# Machine-specific locations are cache variables so they can be overridden on
# the command line (-DMNN_DIR=... -DSPEAK_SDK_DIR=...) without editing this
# file. The defaults are the paths that were previously hard-coded.
set(MNN_DIR /home/situ/MNN/MNN1.0/MNN
    CACHE PATH "MNN checkout; include/ and build/ are expected below it")
set(SPEAK_SDK_DIR /home/situ/qfs/sdk_project/mnn_projects/speak_recognize_mnn
    CACHE PATH "speak_recognize_mnn project root; include/, tools/ and lib/ are expected below it")

include_directories(${MNN_DIR}/include)
link_directories(${MNN_DIR}/build)
include_directories(${SPEAK_SDK_DIR}/include)
# SOURCE_CPP is only consumed by the (currently disabled) add_library line below.
aux_source_directory(${SPEAK_SDK_DIR}/tools SOURCE_CPP)
link_directories(${SPEAK_SDK_DIR}/lib)

# add_library(speakrecognize SHARED speak_detector.cpp ${SOURCE_CPP})
# add_executable(speakrecognize main.cpp retinaface.cpp facelandmarks.cpp speakcls.cpp speak_detector.cpp)
add_executable(speakrecognize main.cpp)

# target_link_libraries(speakrecognize -lMNN ${OpenCV_LIBS})
# The explicit -l form is kept on purpose: the bare name "speakrecognize" is
# also the name of this executable target, so CMake would resolve it to the
# target itself rather than to the library found via link_directories().
target_link_libraries(speakrecognize -lspeakrecognize -lMNN ${OpenCV_LIBS})
1 #ifndef FACELANDMARKS_H
2 #define FACELANDMARKS_H
3 #include <opencv2/opencv.hpp>
4 #include<MNN/Interpreter.hpp>
5 #include<MNN/ImageProcess.hpp>
6 #include<iostream>
7 #include<memory>
8
9 using namespace std;
10 using namespace cv;
11 using namespace MNN;
12
13 class FaceLandmarks{
14 public:
15 int num_thread = 2;
16 MNNForwardType forward_type = MNN_FORWARD_CPU;
17
18 public:
19 FaceLandmarks(){};
20 // ~FaceLandmarks();
21 bool init_model(string model_path);
22 vector<vector<float>> inference(string image_path);
23 vector<vector<float>> inference(Mat image);
24
25 private:
26 bool model_init;
27 float normal[3]={1.0f/256.f,1.0f/256.f,1.0f/256.f};
28 std::shared_ptr<MNN::Interpreter> pfld_interpreter = nullptr;
29 MNN::Session* session = nullptr;
30 MNN::Tensor* input_tensor = nullptr;
31 shared_ptr<MNN::CV::ImageProcess> pretreat;
32 };
33 #endif
1 #ifndef RETINAFACE_H
2 #define RETINAFACE_H
3 #include<opencv2/opencv.hpp>
4 #include<MNN/Interpreter.hpp>
5 #include<MNN/ImageProcess.hpp>
6 #include<iostream>
7 #include<memory>
8
9 using namespace MNN;
10 using namespace std;
11 using namespace cv;
/// One detection record: an axis-aligned box, a score and five (x, y) points.
/// Plain aggregate; field order is part of the layout and must not change.
struct Bbox {
    // Box corners: (xmin, ymin) and (xmax, ymax).
    float xmin, ymin, xmax, ymax;
    // Detection score attached to the box.
    float score;
    // Five point coordinates, stored as consecutive (x, y) pairs.
    // NOTE(review): presumably the five facial key points of the detector - confirm.
    float x1, y1;
    float x2, y2;
    float x3, y3;
    float x4, y4;
    float x5, y5;
};
29 class RetinaFace{
30 public:
31 float confidence_threshold = 0.5;
32 bool is_bbox_process=true;
33 int num_thread = 2;
34 MNNForwardType forward_type = MNN_FORWARD_CPU;
35 private:
36 bool model_init=false;
37 vector<int> input_size={640,640};
38 vector<float> variances={0.1,0.2};
39 float mean[3] = {104.0f, 117.0f, 123.0f};
40 float keep_top_k = 100;
41 float nms_threshold = 0.4;
42 float resize_scale = 1.0;
43
44 std::shared_ptr<MNN::Interpreter> net;
45 Session *session = nullptr;
46 MNN::Tensor* input_tensor=nullptr;
47 shared_ptr<MNN::CV::ImageProcess> pretreat;
48 vector<vector<float>> anchors;
49
50 private:
51 // 生成anchors
52 vector<vector<float>> priorBox(vector<int> image_size);
53 // 解析bounding box landmarks 包含置信度
54 vector<Bbox> decode(float *loc,float *score,float *pre,vector<vector<float>> priors,vector<float> variances);
55 // 解析landmarks
56 // vector<vector<float>> decode_landm(vector<vector<float>> pre,vector<vector<float>> priors,vector<float> variances);
57 //NMS
58 void nms_cpu(std::vector<Bbox> &bboxes, float threshold);
59 // 根据阈值筛选
60 vector<Bbox> select_score(vector<Bbox> bboxes,float threshold,float w_r,float h_r);
61 // 数据后处理
62 vector<Bbox> bbox_process(vector<Bbox> bboxes,float frame_w,float frame_h);
63
64 public:
65
66 RetinaFace(){};
67 // ~RetinaFace();
68 bool init_model(string model_path);
69
70 // 推理
71 vector<Bbox> inference(string image_path);
72 vector<Bbox> inference(Mat image);
73 };
74 #endif
...\ No newline at end of file ...\ No newline at end of file
1 #ifndef SPEAKCLS_DETECTOR
2 #define SPEALCLS_DETECTOR
3 #include "speakcls.h"
4 #include "retinaface.h"
5 #include "facelandmarks.h"
6
7 class SpeakDetector{
8 private:
9 RetinaFace face_det;
10 FaceLandmarks landm_det;
11 SpeakCls speak_cls;
12
13
14 public:
15 SpeakDetector(){};
16 void init_model(string face_det_model,string landm_det_model,string speak_cls_model);
17 float iou_compute(Bbox b1, Bbox b2);
18 vector<vector<cv::Mat>> mouth_process(vector<vector<vector<vector<float>>>> batch_landmarks, vector<cv::Mat> batch_images);
19 void image_reader(string file_path,int segment_num,vector<Mat> &bgr_frames,vector<vector<int>> &indices);
20
21
22 void speak_recognize(string image_path);
23
24 };
25
26
27 #endif
...\ No newline at end of file ...\ No newline at end of file
1 #ifndef SPEAKCLS_H
2 #define SPEALCLS_H
3 #include<opencv2/opencv.hpp>
4 #include<MNN/Interpreter.hpp>
5 #include<MNN/ImageProcess.hpp>
6 #include<memory>
7 using namespace std;
8 using namespace cv;
9 using namespace MNN;
10 class SpeakCls{
11 private:
12 std::shared_ptr<MNN::Interpreter> net;
13 MNN::Session* session;
14 MNN::Tensor* input_tensor;
15 ScheduleConfig config;
16 int split_nums = 10;
17
18 public:
19 SpeakCls(){};
20 bool init_model(string model_path);
21 bool inference(vector<Mat> images);
22
23 private:
24 cv::Mat standardize(cv::Mat image);
25 cv::Mat data_process(vector<Mat> images);
26 vector<double> softmax(vector<double> input);
27
28 };
29 #endif
...\ No newline at end of file ...\ No newline at end of file
No preview for this file type
1 #include "speak_detector.h"
2
3 int main(){
4
5 SpeakDetector speak = SpeakDetector();
6 string face_det_model = "/home/situ/qfs/sdk_project/mnn_projects/speak_recognize_mnn/model/det_face_retina_mnn_1.0.0_v0.1.1.mnn";
7 string face_landm_model = "/home/situ/qfs/sdk_project/mnn_projects/speak_recognize_mnn/model/det_landmarks_106_v0.0.1.mnn";
8 string speakcls_model = "/home/situ/qfs/sdk_project/mnn_projects/speak_recognize_mnn/model/cls_speak_v0.2.2.mnn";
9 speak.init_model(face_det_model,face_landm_model,speakcls_model);
10 speak.speak_recognize("/data/speak/bank_test/no_speak/2395QUESTION_ANSWER");
11 return 0;
12 }
No preview for this file type
No preview for this file type
No preview for this file type
1 #include "speak_detector.h"
2
3 void SpeakDetector::init_model(string face_det_model,string landm_det_model,string speak_cls_model){
4 face_det = RetinaFace();
5 face_det.init_model(face_det_model);
6 landm_det = FaceLandmarks();
7 landm_det.init_model(landm_det_model);
8 speak_cls = SpeakCls();
9 speak_cls.init_model(speak_cls_model);
10 }
11
12 float SpeakDetector::iou_compute(Bbox b1, Bbox b2)
13 {
14 float tmp_w=min(b1.xmax,b2.xmax) - max(b1.xmin, b2.xmin);
15 float tmp_h=min(b1.ymax, b2.ymax) - max(b1.ymin, b2.ymin);
16 float w = max(tmp_w, float(0));
17 float h = max(tmp_h, float(0));
18 return w*h / ((b1.xmax-b1.xmin)*(b1.ymax-b1.ymin) + (b2.xmax-b2.xmin)*(b2.ymax-b2.ymin) - w*h);
19 }
20 vector<vector<cv::Mat>> SpeakDetector::mouth_process(vector<vector<vector<vector<float>>>> batch_landmarks, vector<cv::Mat> batch_images){
21 int input_size=112;
22 vector<vector<cv::Mat>> align_mouths;
23 for(int i=0;i<batch_images.size();++i){
24 cv::Mat image = batch_images[i];
25 vector<cv::Mat> tmp_mouths;
26 for(int j=0;j<batch_landmarks[i].size();++j){
27 vector<float> mouth_xs;
28 vector<float> mouth_ys;
29 for(int k=84;k<int(104);++k){
30 float x_q = round(batch_landmarks[i][j][k][0]);
31 float y_q = round(batch_landmarks[i][j][k][1]);
32 mouth_xs.push_back(x_q);
33 mouth_ys.push_back(y_q);
34 }
35 float mouth_width=*max_element(mouth_xs.begin(),mouth_xs.end())-*min_element(mouth_xs.begin(),mouth_xs.end());
36 float mouth_height=*max_element(mouth_ys.begin(),mouth_ys.end())-*min_element(mouth_ys.begin(),mouth_ys.end());
37 int mouth_min_x=ceil(*min_element(mouth_xs.begin(),mouth_xs.end())-mouth_width*0.2);
38 int mouth_min_y=ceil(*min_element(mouth_ys.begin(),mouth_ys.end())-mouth_height*0.1);
39 int mouth_max_x=ceil(*max_element(mouth_xs.begin(),mouth_xs.end())+mouth_width*0.2);
40 int mouth_max_y=ceil(*max_element(mouth_ys.begin(),mouth_ys.end())+mouth_height*0.1);
41
42 mouth_min_x=mouth_min_x>0?mouth_min_x:0;
43 mouth_min_y=mouth_min_y>0?mouth_min_y:0;
44 cv::Rect mouth_rect = Rect(mouth_min_x,mouth_min_y,mouth_max_x-mouth_min_x,mouth_max_y-mouth_min_y);
45 cv::Mat mouth_crop = image(mouth_rect);
46 cv::Mat resize_mouth_crop;
47 cv::resize(mouth_crop,resize_mouth_crop,Size(input_size,input_size));
48 Point center=Point(input_size/2,input_size/2);
49 float dx = batch_landmarks[i][j][90][0]-batch_landmarks[i][j][84][0];
50 float dy = batch_landmarks[i][j][90][1]-batch_landmarks[i][j][84][1];
51 double angle = atan2(dy,dx)*180/float(M_PI);
52 cv::Mat rotate_matrix = cv::getRotationMatrix2D(center,double(angle),1);
53 cv::Mat rot_img;
54 cv::warpAffine(resize_mouth_crop,rot_img,rotate_matrix,Size(input_size,input_size));
55 tmp_mouths.push_back(rot_img);
56 }
57 align_mouths.push_back(tmp_mouths);
58 }
59 return align_mouths;
60 }
61 //视频/图像数据切片
62 //图像
63 void SpeakDetector::image_reader(string file_path,int segment_num,vector<Mat> &bgr_frames,vector<vector<int>> &indices){
64 int new_length = 1;
65 vector<String> image_files;
66 glob(file_path, image_files, false);
67 int total_frames_num = (int)image_files.size();
68 float tick = float(total_frames_num - new_length + 1) / float(segment_num);
69 vector<int> indice;
70 for(int x=0;x<segment_num;++x){
71 indice.push_back(int(tick / 2.0 + tick * x));
72 }
73 indices.push_back(indice);
74
75 for(auto im_file:image_files){
76 Mat bgr_img=cv::imread(im_file);
77 bgr_frames.push_back(bgr_img);
78 }
79 }
80 void SpeakDetector::speak_recognize(string image_path){
81 vector<Mat> all_bgr_images;
82 vector<vector<int>> total_split_indices;
83 image_reader(image_path,10,all_bgr_images,total_split_indices);
84 // vector<json> all_results;
85
86 bool is_talk=false;
87
88 for(int im_i=0;im_i<total_split_indices.size();++im_i){
89 vector<vector<cv::Mat>> face_list;
90 vector<vector<Bbox>> bbox_list;
91 vector<cv::Mat> rgb_frames;
92 vector<cv::Mat> bgr_frames;
93 int tmp_rows,tmp_cols;
94 for(int im_j=0;im_j<total_split_indices[im_i].size();++im_j){
95 Mat tmp_img=all_bgr_images[total_split_indices[im_i][im_j]];
96 if(im_j !=0){
97 if(tmp_img.rows!=tmp_rows&&tmp_img.cols!=tmp_cols){
98 cv::resize(tmp_img,tmp_img,Size(int(tmp_img.cols),int(tmp_img.rows)));
99 }
100 }
101 tmp_rows=tmp_img.rows;
102 tmp_cols=tmp_img.cols;
103 Mat rgb_tmp_img;
104 cv::cvtColor(tmp_img,rgb_tmp_img,cv::COLOR_BGR2RGB);
105 bgr_frames.push_back(tmp_img);
106 rgb_frames.push_back(rgb_tmp_img);
107
108 }
109 for(auto bgr_frame:bgr_frames){
110 vector<Bbox> boxes=face_det.inference(bgr_frame);
111 vector<cv::Mat> tmp_face_areas;
112 vector<Bbox> tmp_bbox_list;
113 for(auto box:boxes){
114 tmp_bbox_list.push_back(box);
115
116 // cout<<box.xmin<<" "<<box.ymin<<" "<<box.xmax-box.xmin<<" "<<box.ymax-box.ymin<<endl;
117
118 Rect m_select = Rect(box.xmin,box.ymin,box.xmax-box.xmin,box.ymax-box.ymin);
119
120 cv::Mat face_area=bgr_frame(m_select);
121 tmp_face_areas.push_back(face_area);
122 // cv::waitKey(0);
123 }
124 face_list.push_back(tmp_face_areas);
125 bbox_list.push_back(tmp_bbox_list);
126 }
127 // cout<<123<<endl;
128 vector<vector<vector<vector<float>>>> landms_list;
129 for(int i=0;i<face_list.size();++i){
130 vector<vector<vector<float>>> tmp_landm_list;
131 for(int j=0;j<face_list[i].size();++j){
132 vector<vector<float>> tmp_landms=landm_det.inference(face_list[i][j]);
133 for(int k=0;k<tmp_landms.size();++k){
134 tmp_landms[k][0]=tmp_landms[k][0]+bbox_list[i][j].xmin;
135 tmp_landms[k][1]=tmp_landms[k][1]+bbox_list[i][j].ymin;
136 }
137 tmp_landm_list.push_back(tmp_landms);
138 }
139 landms_list.push_back(tmp_landm_list);
140 }
141 vector<vector<cv::Mat>> mouth_list=mouth_process(landms_list,rgb_frames);
142
143 vector<vector<Bbox>> last_bboxes=bbox_list;
144 vector<Bbox> first_bboxes = bbox_list[0];
145 vector<vector<Bbox>>::iterator k = last_bboxes.begin();
146 last_bboxes.erase(k);
147
148 vector<vector<Bbox>> all_track_bbox_list;
149 vector<vector<cv::Mat>> all_face_list,all_mouth_list;
150
151 for(int i=0;i<first_bboxes.size();++i){
152 Bbox first_bbox=first_bboxes[i];
153 vector<Bbox> track_bbox_list;
154 vector<cv::Mat> trace_face_list,trace_mouth_list;
155 track_bbox_list.push_back(first_bbox);
156 trace_face_list.push_back(face_list[0][i]);
157 trace_mouth_list.push_back(mouth_list[0][i]);
158 for(int j=0;j<last_bboxes.size();++j){
159 vector<Bbox> next_bboxes=last_bboxes[j];
160 for(int k=0;k<next_bboxes.size();++k){
161
162 Bbox next_bbox = next_bboxes[k];
163 float iou=iou_compute(first_bbox,next_bbox);
164 if(iou>=0.4){
165 track_bbox_list.push_back(next_bbox);
166 trace_face_list.push_back(face_list[j+1][k]);
167 trace_mouth_list.push_back(mouth_list[j+1][k]);
168 break;
169 }
170 }
171 }
172 all_track_bbox_list.push_back(track_bbox_list);
173 all_face_list.push_back(trace_face_list);
174 all_mouth_list.push_back(trace_mouth_list);
175 }
176 for(int j=0;j<all_mouth_list.size();j++){
177 vector<cv::Mat> select_mouth_list=all_mouth_list[j];
178
179 /**
180 * @brief 模型推理部分代码,返回result 0/1 ,其中1为说话,0为未说话
181 *
182 */
183 bool result=speak_cls.inference(select_mouth_list);
184
185 // bool result=true;
186 if(result){
187 is_talk=true;
188 // speak_duration = (split_indices[0], split_indices[-1])
189 // Mat speaker = all_face_list[j][0];
190 // speaker_str = cv::imencode('.jpg', speaker)[1].tostring()
191 // speaker_str = base64.b64encode(speaker_str).decode()
192 // int position = j
193 // json cur_output={
194 // "is_talk":true,
195 // "speak_duration":[str(speak_duration[0]), str(speak_duration[1])],
196 // "speaker":speaker_str,
197 // "position":position
198 // }
199 // all_results.push_back(cur_output);
200 cout<<is_talk<<endl;
201 }else{
202 cout<<is_talk<<endl;
203 }
204
205 }
206
207 // return 0;
208 }
209 }
...\ No newline at end of file ...\ No newline at end of file
1 #include "facelandmarks.h"
2
3
4 // FaceLandmarks::~FaceLandmarks(){
5 // pfld_interpreter->releaseModel();
6 // pfld_interpreter->releaseSession(session);
7 // }
8
9 bool FaceLandmarks::init_model(string model_path){
10 pfld_interpreter = unique_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(model_path.c_str()));
11 if(nullptr==pfld_interpreter){
12 return false;
13 }
14 //创建session
15 MNN::ScheduleConfig schedule_config;
16 schedule_config.type = forward_type;
17 schedule_config.numThread = num_thread;
18 MNN::BackendConfig backend_config;
19 backend_config.memory = MNN::BackendConfig::Memory_Normal;
20 backend_config.power = MNN::BackendConfig::Power_Normal;
21 backend_config.precision = MNN::BackendConfig::Precision_Normal;
22 schedule_config.backendConfig = &backend_config;
23 session = pfld_interpreter->createSession(schedule_config);
24 input_tensor = pfld_interpreter->getSessionInput(session,NULL);
25 pfld_interpreter->resizeTensor(input_tensor,{1,3,112,112});
26 pfld_interpreter->resizeSession(session);
27
28 //数据预处理
29 MNN::CV::ImageProcess::Config image_config;
30 ::memcpy(image_config.normal,normal,sizeof(normal));
31 image_config.sourceFormat = MNN::CV::BGR;
32 image_config.destFormat = MNN::CV::BGR;
33
34 pretreat = shared_ptr<MNN::CV::ImageProcess>(MNN::CV::ImageProcess::create(image_config));
35 // pretreat->setMatrix(transforms);
36
37 return true;
38 }
39
40 vector<vector<float>> FaceLandmarks::inference(string image_path){
41 Mat image = cv::imread(image_path);
42 vector<vector<float>> landmarks;
43 int width = image.cols;
44 int height = image.rows;
45 Mat resize_image;
46 cv::resize(image,resize_image,Size(112,112));
47 float ws = float(width)/float(112.0);
48 float hs = float(height)/float(112.0);
49
50 pretreat->convert(resize_image.data,112,112,0,input_tensor);
51
52 pfld_interpreter->runSession(session);
53
54 auto output_landmark = pfld_interpreter->getSessionOutput(session, NULL);
55 MNN::Tensor landmark_tensor(output_landmark, output_landmark->getDimensionType());
56 output_landmark->copyToHostTensor(&landmark_tensor);
57 float* result = landmark_tensor.host<float>();
58 for (int i = 0; i < 106; ++i) {
59 vector<float> curr_pt={result[2 * i + 0] * ws,result[2 * i + 1] * hs};
60 landmarks.push_back(curr_pt);
61 }
62 return landmarks;
63 }
64
65 vector<vector<float>> FaceLandmarks::inference(Mat image){
66 vector<vector<float>> landmarks;
67 int width = image.cols;
68 int height = image.rows;
69 Mat resize_image;
70 cv::resize(image,resize_image,Size(112,112));
71 float ws = float(width)/float(112.0);
72 float hs = float(height)/float(112.0);
73
74 pretreat->convert(resize_image.data,112,112,0,input_tensor);
75
76 pfld_interpreter->runSession(session);
77
78 auto output_landmark = pfld_interpreter->getSessionOutput(session, NULL);
79 MNN::Tensor landmark_tensor(output_landmark, output_landmark->getDimensionType());
80 output_landmark->copyToHostTensor(&landmark_tensor);
81 float* result = landmark_tensor.host<float>();
82 for (int i = 0; i < 106; ++i) {
83 vector<float> curr_pt={result[2 * i + 0] * ws,result[2 * i + 1] * hs};
84 landmarks.push_back(curr_pt);
85 }
86 return landmarks;
87 }
...\ No newline at end of file ...\ No newline at end of file
1 #include "speakcls.h"
2
3
4 bool SpeakCls::init_model(string model_path){
5 net= std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(model_path.c_str()));//创建解释器
6 config.numThread = 2;
7 config.type = MNN_FORWARD_CPU;
8 session = net->createSession(config);
9 input_tensor = net->getSessionInput(session,NULL);
10 net->resizeTensor(input_tensor,{1,3*split_nums,112,112});
11 net->resizeSession(session);
12 }
13
14 cv::Mat SpeakCls::standardize(cv::Mat image){
15 cv::Mat image_f,dst;
16 image.convertTo(image_f, CV_32F);
17 Scalar max_pix = Scalar(255.0f,255.0f,255.0f);
18 Scalar mean = Scalar(0.485f, 0.456f, 0.406f);
19 Scalar std = Scalar(0.229f, 0.224f, 0.225f);
20 dst=image_f/max_pix;
21 dst = (dst-mean)/std;
22 return dst;
23 }
24
25 cv::Mat SpeakCls::data_process(vector<Mat> images){
26 std::vector<cv::Mat> all_image_channels;
27 for(auto f:images){
28 Mat tmp_image = standardize(f);
29 std::vector<cv::Mat> tmp_channels;
30 cv::split(tmp_image,tmp_channels);
31 all_image_channels.push_back(tmp_channels[0]);
32 all_image_channels.push_back(tmp_channels[1]);
33 all_image_channels.push_back(tmp_channels[2]);
34 }
35 Mat input_data;
36 cv::merge(all_image_channels,input_data);
37 return input_data;
38 }
39
40 vector<double> SpeakCls::softmax(vector<double> input){
41 double total=0;
42 for(auto x:input)
43 {
44 total+=exp(x);
45 }
46 vector<double> result;
47 for(auto x:input)
48 {
49 result.push_back(exp(x)/total);
50 }
51 return result;
52 }
53
54 bool SpeakCls::inference(vector<Mat> images){
55
56 Mat input_data=data_process(images);
57 // cout << _Tensor->elementSize() << endl;
58 std::vector<std::vector<cv::Mat>> nChannels;
59 std::vector<cv::Mat> rgbChannels(3*split_nums);
60 cv::split(input_data, rgbChannels);
61 nChannels.push_back(rgbChannels); // NHWC 转NCHW
62 auto *pvData = malloc(1 * 3*split_nums * 112 * 112 *sizeof(float));
63 int nPlaneSize = 112 * 112;
64 for (int c = 0; c < 3*split_nums; ++c)
65 {
66 cv::Mat matPlane = nChannels[0][c];
67 memcpy((float *)(pvData) + c * nPlaneSize,\
68 matPlane.data, nPlaneSize * sizeof(float));
69 }
70
71 auto nchwTensor = new Tensor(input_tensor, Tensor::CAFFE);
72 ::memcpy(nchwTensor->host<float>(), pvData, nPlaneSize * 3*split_nums * sizeof(float));
73
74 input_tensor->copyFromHostTensor(nchwTensor);
75 //推理
76 net->runSession(session);
77 auto output= net->getSessionOutput(session, NULL);
78
79 MNN::Tensor feat_tensor(output, output->getDimensionType());
80 output->copyToHostTensor(&feat_tensor);
81
82 auto scores_dataPtr = feat_tensor.host<float>();
83
84 cout<<scores_dataPtr[0]<<" "<<scores_dataPtr[1]<<endl;
85 vector<double> outputs={scores_dataPtr[0],scores_dataPtr[1]};
86 // softmax
87 vector<double> result=softmax(outputs);
88
89 printf("output belong to class: %f %f\n", result[0],result[1]);
90 if(result[0]>result[1]){
91 return false;
92 }else{
93 return true;
94 }
95 }
...\ No newline at end of file ...\ No newline at end of file
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!