60c39554 by 周伟奇

add draw

1 parent 890ea78a
......@@ -3,9 +3,9 @@ seed: 3407
dataset:
name: 'SLData'
args:
data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2'
train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv'
val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv'
data_root: '/dataset160x14'
train_anno_file: '/dataset160x14/train.csv'
val_anno_file: '/dataset160x14/valid.csv'
dataloader:
batch_size: 8
......@@ -18,7 +18,7 @@ model:
args:
seq_lens: 160
num_classes: 10
embed_dim: 9
embed_dim: 14
depth: 6
num_heads: 1
mlp_ratio: 4.0
......@@ -36,6 +36,11 @@ solver:
epoch: 100
base_on: null
model_path: null
val_image_path: '/labeled/valid/image'
val_go_path: '/go_res/valid'
val_map_path: '/dataset160x14/create_map.json'
draw_font_path: '/dataset160x14/STZHONGS.TTF'
thresholds: 0.5
optimizer:
name: 'Adam'
......@@ -58,5 +63,5 @@ solver:
alpha: 0.8
logger:
log_root: '/Users/zhouweiqi/Downloads/test/logs'
log_root: '/logs'
suffix: 'sl-6-1'
\ No newline at end of file
......
......@@ -7,7 +7,7 @@ import uuid
import cv2
import pandas as pd
from tools import get_file_paths, load_json
from word2vec import simple_word2vec, jwq_word2vec
from word2vec import jwq_word2vec, simple_word2vec
def clean_go_res(go_res_dir):
......@@ -101,7 +101,7 @@ def build_anno_file(dataset_dir, anno_file_path):
df['name'] = img_list
df.to_csv(anno_file_path)
def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir, is_create_map=False):
"""
Args:
img_dir: str 图片目录
......@@ -121,6 +121,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
group_cn_list = ['开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]
create_map = {}
for img_name in sorted(os.listdir(img_dir)):
if img_name in skip_list:
print('Info: skip {0}'.format(img_name))
......@@ -188,8 +189,9 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
X = list()
y_true = list()
text_vec_max_lens = 15 * 50
dim = 1 + 5 + 8 + text_vec_max_lens
# text_vec_max_lens = 15 * 50
# dim = 1 + 5 + 8 + text_vec_max_lens
dim = 1 + 5 + 8
num_classes = 10
for i in range(160):
if i >= valid_lens:
......@@ -201,7 +203,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
feature_vec = [1.]
feature_vec.extend(simple_word2vec(text))
feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
# feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
X.append(feature_vec)
y_true.append([0 for _ in range(num_classes)])
......@@ -211,7 +213,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
feature_vec = [0.]
feature_vec.extend(simple_word2vec(text))
feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
# feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
X.append(feature_vec)
base_label_list = [0 for _ in range(num_classes)]
......@@ -222,16 +224,34 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
feature_vec = [0.]
feature_vec.extend(simple_word2vec(text))
feature_vec.extend([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
# feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
X.append(feature_vec)
y_true.append([0 for _ in range(num_classes)])
all_data = [X, y_true, valid_lens]
with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name))), 'w') as fp:
save_json_name = '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name))
with open(os.path.join(save_dir, save_json_name), 'w') as fp:
json.dump(all_data, fp)
if is_create_map:
create_map[img_name] = {
'x_y_valid_lens': save_json_name,
'find_top_text': [go_res_list[i][-1] for i in top_text_idx_set],
'find_value': {group_cn_list[v]: go_res_list[k][-1] for k, v in label_idx_dict.items()}
}
# break
# print(create_map)
# print(is_create_map)
if create_map:
with open(os.path.join(os.path.dirname(save_dir), 'create_map.json'), 'w') as fp:
json.dump(create_map, fp)
# print('top text find:')
# for i in top_text_idx_set:
# _, text = go_res_list[i]
......@@ -249,7 +269,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
if __name__ == '__main__':
base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
go_dir = os.path.join(base_dir, 'go_res')
dataset_save_dir = os.path.join(base_dir, 'dataset2')
dataset_save_dir = os.path.join(base_dir, 'dataset160x14')
label_dir = os.path.join(base_dir, 'labeled')
train_go_path = os.path.join(go_dir, 'train')
......@@ -331,7 +351,7 @@ if __name__ == '__main__':
build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
build_anno_file(train_dataset_dir, train_anno_file_path)
build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir, True)
build_anno_file(valid_dataset_dir, valid_anno_file_path)
# print(simple_word2vec(' fd2jk接口 额24;叁‘,。测ADF壹试!¥? '))
......
CUDA_VISIBLE_DEVICES=0 nohup python main.py --config=config/sl.yaml -d > draw.log 2>&1 &
\ No newline at end of file
CUDA_VISIBLE_DEVICES=0 nohup python main.py --config=config/sl.yaml -e > eval.log 2>&1 &
\ No newline at end of file
......@@ -8,6 +8,7 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file')
parser.add_argument('-e', '--eval', action="store_true")
parser.add_argument('-d', '--draw', action="store_true")
args = parser.parse_args()
cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader)
......@@ -18,6 +19,8 @@ def main():
if args.eval:
solver.evaluate()
elif args.draw:
solver.draw_val()
else:
solver.run()
......
......@@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens):
# [batch_size, num_heads, seq_len, seq_len]
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
valid_lens = torch.repeat_interleave(valid_lens, shape[2])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
......
import copy
import os
import cv2
import json
import torch
from PIL import Image, ImageDraw, ImageFont
from data import build_dataloader
from loss import build_loss
......@@ -34,6 +37,11 @@ class SLSolver(object):
self.hyper_params = cfg['solver']['args']
self.base_on = self.hyper_params['base_on']
self.model_path = self.hyper_params['model_path']
self.val_image_path = self.hyper_params['val_image_path']
self.val_go_path = self.hyper_params['val_go_path']
self.val_map_path = self.hyper_params['val_map_path']
self.draw_font_path = self.hyper_params['draw_font_path']
self.thresholds = self.hyper_params['thresholds']
try:
self.epoch = self.hyper_params['epoch']
except Exception:
......@@ -41,19 +49,22 @@ class SLSolver(object):
self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5):
def accuracy(self, y_pred, y_true, valid_lens, eval=False):
# [batch_size, seq_len, num_classes]
y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
# [batch_size, seq_len]
y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
# [batch_size, seq_len]
y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int()
y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > self.thresholds).int()
y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
y_true_idx = torch.argmax(y_true, dim=-1) + 1
y_true_is_other = torch.sum(y_true, dim=-1).int()
y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
if eval:
return y_pred_rebuild, y_true_rebuild
masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)
return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()
......@@ -168,19 +179,7 @@ class SLSolver(object):
# pred = torch.nn.Sigmoid()(self.model(X))
y_pred = self.model(X, valid_lens)
# [batch_size, seq_len, num_classes]
y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
# [batch_size, seq_len]
y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
# [batch_size, seq_len]
y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > 0.5).int()
y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
y_true_idx = torch.argmax(y_true, dim=-1) + 1
y_true_is_other = torch.sum(y_true, dim=-1).int()
y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
# masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)
y_pred_rebuild, y_true_rebuild = self.accuracy(y_pred, y_true, valid_lens, eval=True)
for idx, seq_result in enumerate(y_true_rebuild.cpu().numpy().tolist()):
label_true_list.extend(seq_result[: valid_lens.cpu().numpy()[idx]])
......@@ -193,3 +192,111 @@ class SLSolver(object):
print(acc)
print(cm)
print(report)
def draw_val(self):
    """Run the model over every validation image and save annotated copies.

    For each image in `val_image_path`, loads the pre-built feature tensors
    (via the `create_map.json` mapping), predicts per-token classes, and draws
    the GO-result polygons on the image: green boxes for correct non-"other"
    predictions, red boxes (label at top-left, prediction at top-right) for
    mistakes.  The per-image accuracy and the expected key/value pairs are
    written in blue.  Annotated images are saved under `<dataset>/draw_val`.
    """
    if not os.path.isdir(self.val_image_path):
        print('Warn: val_image_path not exists: {0}'.format(self.val_image_path))
        return
    if not os.path.isdir(self.val_go_path):
        print('Warn: val_go_path not exists: {0}'.format(self.val_go_path))
        return
    if not os.path.isfile(self.val_map_path):
        print('Warn: val_map_path not exists: {0}'.format(self.val_map_path))
        return

    # Keys of the per-image entries inside create_map.json
    map_key_input = 'x_y_valid_lens'
    map_key_text = 'find_top_text'
    map_key_value = 'find_value'
    # Index 0 is the "other" class; indices 1..10 match the dataset's groups
    group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']

    dataset_base_dir = os.path.dirname(self.val_map_path)
    val_dataset_dir = os.path.join(dataset_base_dir, 'valid')
    save_dir = os.path.join(dataset_base_dir, 'draw_val')
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir, exist_ok=True)

    self.model.eval()

    with open(self.val_map_path, 'r') as fp:
        val_map = json.load(fp)

    for img_name in sorted(os.listdir(self.val_image_path)):
        print('Info: start {0}'.format(img_name))
        image_path = os.path.join(self.val_image_path, img_name)

        img = cv2.imread(image_path)
        im_h, im_w, _ = img.shape
        # PIL is used for drawing so that the TTF font can render Chinese text
        img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img_pil)

        # Font size scales with the shorter image edge, floored at 14px
        if im_h < im_w:
            size = int(im_h * 0.015)
        else:
            size = int(im_w * 0.015)
        if size < 14:
            size = 14
        font = ImageFont.truetype(self.draw_font_path, size, encoding='utf-8')

        green_color = (0, 255, 0)
        red_color = (255, 0, 0)
        blue_color = (0, 0, 255)

        base_image_name, _ = os.path.splitext(img_name)
        go_res_json_path = os.path.join(self.val_go_path, '{0}.json'.format(base_image_name))
        with open(go_res_json_path, 'r') as fp:
            go_res_list = json.load(fp)

        with open(os.path.join(val_dataset_dir, val_map[img_name][map_key_input]), 'r') as fp:
            input_list, label_list, valid_lens_scalar = json.load(fp)

        X = torch.tensor(input_list).unsqueeze(0).to(self.device)
        y_true = torch.tensor(label_list).unsqueeze(0).float().to(self.device)
        # BUG FIX: was torch.tenor(...), which raises AttributeError
        valid_lens = torch.tensor([valid_lens_scalar, ]).to(self.device)
        del input_list
        del label_list

        y_pred = self.model(X, valid_lens)
        y_pred_rebuild, y_true_rebuild = self.accuracy(y_pred, y_true, valid_lens, eval=True)
        pred = y_pred_rebuild.cpu().numpy().tolist()[0]
        label = y_true_rebuild.cpu().numpy().tolist()[0]

        correct = 0
        bbox_draw_dict = dict()
        for i in range(valid_lens_scalar):
            if pred[i] == label[i]:
                correct += 1
                if pred[i] != 0:
                    # green: correct prediction of a real (non-"other") class
                    bbox_draw_dict[i] = (group_cn_list[pred[i]], )
            else:
                # red: label at top-left, prediction at top-right
                bbox_draw_dict[i] = (group_cn_list[label[i]], group_cn_list[pred[i]])
        # Guard against empty sequences to avoid ZeroDivisionError
        correct_rate = correct / valid_lens_scalar if valid_lens_scalar else 0.0

        # Draw the annotated boxes and texts
        for idx, text_tuple in bbox_draw_dict.items():
            (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[idx]
            line_color = green_color if len(text_tuple) == 1 else red_color
            draw.polygon([(x0, y0), (x1, y1), (x2, y2), (x3, y3)], outline=line_color)
            draw.text((int(x0), int(y0)), text_tuple[0], green_color, font=font)
            if len(text_tuple) == 2:
                draw.text((int(x1), int(y1)), text_tuple[1], red_color, font=font)

        # Accuracy in the top-left corner, then the expected key/value pairs
        draw.text((0, 0), str(correct_rate), blue_color, font=font)
        last_y = size
        for k, v in val_map[img_name][map_key_value].items():
            draw.text((0, last_y), '{0}: {1}'.format(k, v), blue_color, font=font)
            last_y += size

        img_pil.save(os.path.join(save_dir, img_name))

        # break
......
CUDA_VISIBLE_DEVICES=0 nohup python main.py > train.log 2>&1 &
\ No newline at end of file
CUDA_VISIBLE_DEVICES=0 nohup python main.py --config=config/sl.yaml > train.log 2>&1 &
\ No newline at end of file
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!