05e0f320 by 周伟奇

add simple word2vec

1 parent 3e58f6b0
......@@ -55,7 +55,7 @@ solver:
# name: 'CrossEntropyLoss'
args:
reduction: "mean"
alpha: 0.95
alpha: 0.8
logger:
log_root: '/Users/zhouweiqi/Downloads/test/logs'
......
......@@ -7,34 +7,55 @@ import uuid
import cv2
import pandas as pd
from tools import get_file_paths, load_json
from word2vec import simple_word2vec
def clean_go_res(go_res_dir):
    """Clean every OCR-result json file under *go_res_dir* in place.

    For each json file (a list of ``[bbox, text]`` pairs):
      * drop entries whose text is empty after stripping whitespace,
      * sort the remaining entries top-to-bottom then left-to-right by the
        first bbox corner ``(y0, x0)``,
      * rewrite the file with the cleaned, sorted list.
    """
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
        src_go_res_list = load_json(go_res_json_path)
        # Filter instead of deleting by index: removing items from the list
        # by a set of indices shifts the positions of the remaining items and
        # can delete the wrong entries (or raise IndexError).
        cleaned_list = []
        for item in src_go_res_list:
            # item is [bbox, text]; presumably bbox is the 8-value polygon —
            # only the text is inspected here.
            if item[1].strip() == '':
                print(item[1])
            else:
                cleaned_list.append(item)
        # x[0][1] is y0 and x[0][0] is x0: reading order, rows before columns.
        go_res_list = sorted(cleaned_list, key=lambda x: (x[0][1], x[0][0]))
        with open(go_res_json_path, 'w') as fp:
            json.dump(go_res_list, fp)
        print('Rewrite {0}'.format(go_res_json_path))
def char_length_statistics(go_res_dir):
    """Find the longest stripped text across all OCR-result json files.

    Scans every ``.json`` file under *go_res_dir* (each a list of
    ``[bbox, text]`` pairs) and returns a tuple
    ``(max_char_length, path_of_file_containing_it)``.
    """
    max_char_length = None
    target_file_name = None
    for json_path in get_file_paths(go_res_dir, ['.json', ]):
        print('Info: start {0}'.format(json_path))
        for _, text in load_json(json_path):
            stripped_len = len(text.strip())
            if max_char_length is None or stripped_len > max_char_length:
                max_char_length = stripped_len
                target_file_name = json_path
    return max_char_length, target_file_name
def bbox_statistics(go_res_dir):
max_seq_count = None
seq_sum = 0
file_count = 0
go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
for go_res_json_path in go_res_json_paths:
print('Info: start {0}'.format(go_res_json_path))
go_res_list = load_json(go_res_json_path)
seq_sum += len(go_res_list)
file_count += 1
if max_seq_count is None or len(go_res_list) > max_seq_count:
......@@ -168,21 +189,35 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
y_true = list()
for i in range(160):
if i >= valid_lens:
X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.])
X.append([0. for _ in range(14)])
y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
elif i in top_text_idx_set:
(x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
X.append([1., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
(x0, y0, x1, y1, x2, y2, x3, y3), text = go_res_list[i]
feature_vec = [1.]
feature_vec.extend(simple_word2vec(text))
feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
X.append(feature_vec)
y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
elif i in label_idx_dict:
(x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
(x0, y0, x1, y1, x2, y2, x3, y3), text = go_res_list[i]
feature_vec = [0.]
feature_vec.extend(simple_word2vec(text))
feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
X.append(feature_vec)
base_label_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
base_label_list[label_idx_dict[i]] = 1
y_true.append(base_label_list)
else:
(x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
(x0, y0, x1, y1, x2, y2, x3, y3), text = go_res_list[i]
feature_vec = [0.]
feature_vec.extend(simple_word2vec(text))
feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
X.append(feature_vec)
y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
all_data = [X, y_true, valid_lens]
......@@ -222,11 +257,15 @@ if __name__ == '__main__':
valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')
# max_seq_lens, seq_lens_mean, max_seq_file_name = clean_go_res(go_dir)
# max_seq_lens, seq_lens_mean, max_seq_file_name = bbox_statistics(go_dir)
# print(max_seq_lens) # 152
# print(max_seq_file_name) # CH-B101805176_page_2_img_0.json
# print(max_seq_file_name) # train/CH-B101805176_page_2_img_0.json
# print(seq_lens_mean) # 92
# max_char_lens, target_file_name = char_length_statistics(go_dir)
# print(max_char_lens) # 72
# print(target_file_name) # train/CH-B103053828-4.json
# top_text_list = text_statistics(go_dir)
# for t in top_text_list:
# print(t)
......@@ -288,4 +327,6 @@ if __name__ == '__main__':
build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
build_anno_file(valid_dataset_dir, valid_anno_file_path)
# print(simple_word2vec(' fd2jk接口 额24;叁‘,。测ADF壹试!¥? '))
......
import re
# from gensim.models import word2vec
def simple_word2vec(text):
    """Return a 5-element hand-crafted feature vector describing *text*.

    The vector is::

        [len/100, cn_ratio, en_ratio, digit_ratio, other_ratio]

    where the ratios are computed over the stripped text.  Interior spaces
    are counted separately and deliberately excluded from the feature set,
    so the four ratios need not sum to 1.

    Returns an all-zero vector for empty / whitespace-only input instead of
    raising ZeroDivisionError.
    """
    clean_text = text.strip()
    text_len = len(clean_text)
    if text_len == 0:
        # Guard: every ratio below divides by text_len.  Keep the vector
        # length at 5 so downstream feature widths are unchanged.
        return [0., 0., 0., 0., 0.]
    digit_num = 0
    en_num = 0
    cn_num = 0
    space_num = 0
    other_num = 0
    for char in clean_text:
        if char.isdigit():
            digit_num += 1
        elif re.match(r'[A-Za-z]', char):
            en_num += 1
        elif char.isspace():
            # Interior whitespace only (the ends were stripped).
            space_num += 1
        elif re.match(r'[\u4e00-\u9fa5]', char):
            # CJK Unified Ideographs basic block — common Chinese characters.
            cn_num += 1
        else:
            other_num += 1
    vec = [text_len / 100,
           cn_num / text_len,
           en_num / text_len,
           digit_num / text_len,
           # space_num / text_len,  # intentionally excluded from features
           other_num / text_len,
           ]
    return vec
\ No newline at end of file
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!