40ca6fe1 by 周伟奇

add Seq Labeling solver

1 parent b3694ec8
1 seed: 3407
2
3 dataset:
4 name: 'SLData'
5 args:
6 data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2'
7 train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv'
8 val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv'
9
10 dataloader:
11 batch_size: 8
12 num_workers: 4
13 pin_memory: true
14 shuffle: true
15
16 model:
17 name: 'SLTransformer'
18 args:
19 seq_lens: 200
20 num_classes: 10
21 embed_dim: 9
22 depth: 6
23 num_heads: 1
24 mlp_ratio: 4.0
25 qkv_bias: true
26 qk_scale: null
27 drop_ratio: 0.
28 attn_drop_ratio: 0.
29 drop_path_ratio: 0.
30 norm_layer: null
31 act_layer: null
32
33 solver:
34 name: 'SLSolver'
35 args:
36 epoch: 100
37 base_on: null
38 model_path: null
39
40 optimizer:
41 name: 'Adam'
42 args:
43 lr: !!float 1e-3
44 # weight_decay: !!float 5e-5
45
46 lr_scheduler:
47 name: 'CosineLR'
48 args:
49 epochs: 100
50 lrf: 0.1
51
52 loss:
53 name: 'MaskedSigmoidFocalLoss'
54 # name: 'SigmoidFocalLoss'
55 # name: 'CrossEntropyLoss'
56 args:
57 reduction: "mean"
58 alpha: 0.95
59
60 logger:
61 log_root: '/Users/zhouweiqi/Downloads/test/logs'
62 suffix: 'sl-6-1'
...\ No newline at end of file ...\ No newline at end of file
...@@ -60,6 +60,7 @@ solver: ...@@ -60,6 +60,7 @@ solver:
60 # name: 'CrossEntropyLoss' 60 # name: 'CrossEntropyLoss'
61 args: 61 args:
62 reduction: "mean" 62 reduction: "mean"
63 alpha: 0.95
63 64
64 logger: 65 logger:
65 log_root: '/Users/zhouweiqi/Downloads/test/logs' 66 log_root: '/Users/zhouweiqi/Downloads/test/logs'
......
1 import os
2 import json
3 import torch
4 from torch.utils.data import DataLoader, Dataset
5 import pandas as pd
6 from utils.registery import DATASET_REGISTRY
7
8
@DATASET_REGISTRY.register()
class SLData(Dataset):
    """Sequence-labeling dataset backed by pre-built JSON sample files.

    The annotation CSV lists one sample file name per row; each sample
    JSON contains ``[input_list, label_list, valid_lens]`` as produced by
    the dataset builder script.
    """

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        # phase selects the sub-directory (train / valid) under data_root.
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self):
        # One row in the annotation CSV per sample.
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sample_path = os.path.join(self.data_root, self.phase, row['name'])

        with open(sample_path, 'r') as fp:
            input_list, label_list, valid_lens = json.load(fp)

        # Features as-is; labels cast to float for the sigmoid-based loss.
        return torch.tensor(input_list), torch.tensor(label_list).float(), valid_lens
...\ No newline at end of file ...\ No newline at end of file
...@@ -3,6 +3,7 @@ from torch.utils.data import DataLoader ...@@ -3,6 +3,7 @@ from torch.utils.data import DataLoader
3 from utils.registery import DATASET_REGISTRY 3 from utils.registery import DATASET_REGISTRY
4 4
5 from .CoordinatesData import CoordinatesData 5 from .CoordinatesData import CoordinatesData
6 from .SLData import SLData
6 7
7 8
8 def build_dataset(cfg): 9 def build_dataset(cfg):
......
...@@ -93,7 +93,8 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save ...@@ -93,7 +93,8 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
93 label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name)) 93 label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
94 label_res = load_json(label_json_path) 94 label_res = load_json(label_json_path)
95 95
96 # 开票日期 发票代码 机打号码 车辆类型 电话 96 # 开票日期 发票代码 机打号码 车辆类型 电话
97 # 发动机号码 车架号 帐号 开户银行 小写
97 test_group_id = [1, 2, 5, 9, 20] 98 test_group_id = [1, 2, 5, 9, 20]
98 group_list = [] 99 group_list = []
99 for group_id in test_group_id: 100 for group_id in test_group_id:
......
1 import copy
2 import json
3 import os
4 import random
5 import uuid
6
7 import cv2
8 import pandas as pd
9 from tools import get_file_paths, load_json
10
11
def clean_go_res(go_res_dir):
    """Normalize raw OCR JSON files in place and collect length statistics.

    For every ``.json`` file under *go_res_dir*: drop entries whose text is
    blank, sort the remaining ``(box, text)`` pairs into reading order
    (top-to-bottom, then left-to-right), and rewrite the file as a plain list.

    Args:
        go_res_dir: str, directory containing general-OCR JSON results.

    Returns:
        tuple ``(max_seq_count, seq_lens_mean, max_seq_file_name)``:
        longest sequence length, mean length (integer division) and the file
        holding the maximum.  All ``None`` when the directory has no JSON
        files (previously this path crashed with UnboundLocalError /
        ZeroDivisionError).
    """
    max_seq_count = None
    # FIX: initialize so an empty directory cannot trigger UnboundLocalError.
    max_seq_file_name = None
    seq_sum = 0
    file_count = 0

    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))

        # Collect keys whose OCR text is empty or whitespace-only.
        remove_key_set = set()
        go_res = load_json(go_res_json_path)
        for key, (_, text) in go_res.items():
            if text.strip() == '':
                remove_key_set.add(key)

        for del_key in remove_key_set:
            del go_res[del_key]

        # Sort by (y, x) of the first corner point: rough reading order.
        go_res_list = sorted(go_res.values(), key=lambda x: (x[0][1], x[0][0]))

        with open(go_res_json_path, 'w') as fp:
            json.dump(go_res_list, fp)
        # FIX: typo "Rerewirte" in the log message.
        print('Rewrote {0}'.format(go_res_json_path))

        seq_sum += len(go_res_list)
        file_count += 1
        if max_seq_count is None or len(go_res_list) > max_seq_count:
            max_seq_count = len(go_res_list)
            max_seq_file_name = go_res_json_path

    # FIX: guard against division by zero when no files were processed.
    seq_lens_mean = seq_sum // file_count if file_count else None
    return max_seq_count, seq_lens_mean, max_seq_file_name
46
def text_statistics(go_res_dir):
    """
    Args:
        go_res_dir: str, directory containing general-OCR JSON results
    Returns: list of (text, count) pairs for the most frequent texts
    """
    json_count = 0
    text_dict = {}
    for go_res_json_path in get_file_paths(go_res_dir, ['.json', ]):
        print('Info: start {0}'.format(go_res_json_path))
        json_count += 1
        for _, text in load_json(go_res_json_path).values():
            text_dict[text] = text_dict.get(text, 0) + 1

    top_text_list = []
    # Most frequent first; stop as soon as a text occurs in no more than
    # a third of the files (everything after it is rarer still).
    for text, count in sorted(text_dict.items(), key=lambda item: item[1], reverse=True):
        # Skip the empty string entirely.
        if text == '':
            continue
        if count <= json_count // 3:
            break
        top_text_list.append((text, count))
    return top_text_list
75
def build_anno_file(dataset_dir, anno_file_path):
    """Write a shuffled one-column ('name') CSV listing every file in *dataset_dir*."""
    sample_names = os.listdir(dataset_dir)
    random.shuffle(sample_names)
    frame = pd.DataFrame({'name': sample_names})
    frame.to_csv(anno_file_path)
def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
    """Build fixed-length sequence-labeling samples and save them as JSON.

    For each image, pairs the general-OCR boxes with the human-labeled
    field regions (by nearest center distance) and emits a
    ``[X, y_true, valid_lens]`` triple padded/truncated to 200 positions.

    Args:
        img_dir: str, image directory
        go_res_dir: str, directory of general-OCR JSON results
        label_dir: str, directory of labeled (labelme-style) JSON files
        top_text_list: list of (text, count), most frequent "template" texts
        skip_list: list of image file names to skip
        save_dir: str, output dataset directory
    """
    # Idempotence guard: never rebuild an existing dataset directory.
    if os.path.exists(save_dir):
        return
    else:
        os.makedirs(save_dir, exist_ok=True)

    # Target fields, in label order:
    # issue date, invoice code, printed number, vehicle type, phone,
    # engine number, VIN, account number, bank, amount (lowercase digits)
    group_cn_list = ['开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
    # labelme group_id for each of the 10 target fields above (same order).
    test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]

    for img_name in sorted(os.listdir(img_dir)):
        if img_name in skip_list:
            print('Info: skip {0}'.format(img_name))
            continue

        print('Info: start {0}'.format(img_name))
        image_path = os.path.join(img_dir, img_name)
        img = cv2.imread(image_path)
        # h, w are used below to normalize box coordinates into [0, 1].
        h, w, _ = img.shape
        base_image_name, _ = os.path.splitext(img_name)
        go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
        # go_res_list items look like ((x0..y3), text) — 8 corner coords + text.
        go_res_list = load_json(go_res_json_path)

        # Number of real (non-padding) OCR boxes for this image.
        valid_lens = len(go_res_list)

        # Indices of boxes whose text exactly matches a frequent template
        # text; these get the "is template" feature flag and no label.
        top_text_idx_set = set()
        for top_text, _ in top_text_list:
            for go_idx, (_, text) in enumerate(go_res_list):
                if text == top_text:
                    top_text_idx_set.add(go_idx)
                    break

        label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
        label_res = load_json(label_json_path)

        # Center point of each labeled field region, or None when the
        # field was not annotated on this image (for-else hits `else`).
        group_list = []
        for group_id in test_group_id:
            for item in label_res.get("shapes", []):
                if item.get("group_id") == group_id:
                    x_list = []
                    y_list = []
                    for point in item['points']:
                        x_list.append(point[0])
                        y_list.append(point[1])
                    group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2])
                    break
            else:
                group_list.append(None)

        # Center point of each OCR box (from its 4 corner points).
        go_center_list = []
        for (x0, y0, x1, y1, x2, y2, x3, y3), _ in go_res_list:
            xmin = min(x0, x1, x2, x3)
            ymin = min(y0, y1, y2, y3)
            xmax = max(x0, x1, x2, x3)
            ymax = max(y0, y1, y2, y3)
            xcenter = xmin + (xmax - xmin)/2
            ycenter = ymin + (ymax - ymin)/2
            go_center_list.append((xcenter, ycenter))

        # Greedy nearest-center matching (L1 distance): OCR box index ->
        # label index.  Template boxes and already-matched boxes are
        # excluded, so each OCR box gets at most one label.
        label_idx_dict = dict()
        for label_idx, label_center_list in enumerate(group_list):
            if isinstance(label_center_list, list):
                min_go_key = None
                min_length = None
                for go_idx, (go_x_center, go_y_center) in enumerate(go_center_list):
                    if go_idx in top_text_idx_set or go_idx in label_idx_dict:
                        continue
                    length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1])
                    if min_go_key is None or length < min_length:
                        min_go_key = go_idx
                        min_length = length
                if min_go_key is not None:
                    label_idx_dict[min_go_key] = label_idx

        # Assemble fixed-length sequences (200 = model seq_lens):
        # feature = [is_template_flag, 8 normalized corner coords];
        # label   = 10-dim one-hot (all zero for non-field positions).
        # NOTE(review): sequences longer than 200 boxes are silently
        # truncated here — confirm that is acceptable (max observed 152).
        X = list()
        y_true = list()
        for i in range(200):
            if i >= valid_lens:
                # Padding position.
                X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.])
                y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
            elif i in top_text_idx_set:
                # Template text: flag = 1, never carries a field label.
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
                X.append([1., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
            elif i in label_idx_dict:
                # Matched field value: one-hot label at the field index.
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
                X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                base_label_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
                base_label_list[label_idx_dict[i]] = 1
                y_true.append(base_label_list)
            else:
                # Ordinary OCR box: no flag, no label.
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
                X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

        all_data = [X, y_true, valid_lens]

        # File name is a deterministic UUID of the image name.
        with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name))), 'w') as fp:
            json.dump(all_data, fp)

        # print('top text find:')
        # for i in top_text_idx_set:
        #     _, text = go_res_list[i]
        #     print(text)

        # print('-------------')
        # print('label value find:')
        # for k, v in label_idx_dict.items():
        #     _, text = go_res_list[k]
        #     print('{0}: {1}'.format(group_cn_list[v], text))

        # break
205
206
if __name__ == '__main__':
    # Directory layout: go_res/ holds general-OCR JSON, labeled/ holds
    # the human annotations, dataset2/ receives built samples + index CSVs.
    base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
    go_dir = os.path.join(base_dir, 'go_res')
    dataset_save_dir = os.path.join(base_dir, 'dataset2')
    label_dir = os.path.join(base_dir, 'labeled')

    train_go_path = os.path.join(go_dir, 'train')
    train_image_path = os.path.join(label_dir, 'train', 'image')
    train_label_path = os.path.join(label_dir, 'train', 'label')
    train_dataset_dir = os.path.join(dataset_save_dir, 'train')
    train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')

    valid_go_path = os.path.join(go_dir, 'valid')
    valid_image_path = os.path.join(label_dir, 'valid', 'image')
    valid_label_path = os.path.join(label_dir, 'valid', 'label')
    valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
    valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')

    # One-off statistics run (results recorded inline below).
    # max_seq_lens, seq_lens_mean, max_seq_file_name = clean_go_res(go_dir)
    # print(max_seq_lens)       # 152
    # print(max_seq_file_name)  # CH-B101805176_page_2_img_0.json
    # print(seq_lens_mean)      # 92

    # top_text_list = text_statistics(go_dir)
    # for t in top_text_list:
    #     print(t)

    # Frequent "template" texts (field captions etc.) produced by
    # text_statistics(); the counts are kept for reference only.
    # NOTE(review): ('税额', 465) is out of descending-count order —
    # looks manually edited; confirm the list is intentional.
    filter_from_top_text_list = [
        ('机器编号', 496),
        ('购买方名称', 496),
        ('合格证号', 495),
        ('进口证明书号', 495),
        ('机打代码', 494),
        ('车辆类型', 492),
        ('完税凭证号码', 492),
        ('机打号码', 491),
        ('发动机号码', 491),
        ('主管税务', 491),
        ('价税合计', 489),
        ('机关及代码', 489),
        ('销货单位名称', 486),
        ('厂牌型号', 485),
        ('产地', 485),
        ('商检单号', 483),
        ('电话', 476),
        ('开户银行', 472),
        ('车辆识别代号/车架号码', 463),
        ('身份证号码', 454),
        ('吨位', 452),
        ('备注:一车一票', 439),
        ('地', 432),
        ('账号', 431),
        ('统一社会信用代码/', 424),
        ('限乘人数', 404),
        ('税额', 465),
        ('址', 392)
    ]

    # Images excluded from the training split (bad scans etc.).
    skip_list_train = [
        'CH-B101910792-page-12.jpg',
        'CH-B101655312-page-13.jpg',
        'CH-B102278656.jpg',
        'CH-B101846620_page_1_img_0.jpg',
        'CH-B103062528-0.jpg',
        'CH-B102613120-3.jpg',
        'CH-B102997980-3.jpg',
        'CH-B102680060-3.jpg',
        # 'CH-B102995500-2.jpg',  # no labeled value
    ]

    # Images excluded from the validation split.
    skip_list_valid = [
        'CH-B102897920-2.jpg',
        'CH-B102551284-0.jpg',
        'CH-B102879376-2.jpg',
        'CH-B101509488-page-16.jpg',
        'CH-B102708352-2.jpg',
    ]

    build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
    build_anno_file(train_dataset_dir, train_anno_file_path)

    build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
    build_anno_file(valid_dataset_dir, valid_anno_file_path)
290
291
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
2 import torch 2 import torch
3 import inspect 3 import inspect
4 from utils.registery import LOSS_REGISTRY 4 from utils.registery import LOSS_REGISTRY
5 from utils import sequence_mask
5 from torchvision.ops import sigmoid_focal_loss 6 from torchvision.ops import sigmoid_focal_loss
6 7
7 class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): 8 class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
...@@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): ...@@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
21 def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 22 def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
22 return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction) 23 return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)
23 24
class MaskedSigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
    """Sigmoid focal loss that ignores padded sequence positions.

    ``forward`` additionally takes per-sample valid lengths and zeroes
    the loss at every position beyond them before averaging over the
    class axis.

    NOTE(review): ``reduction`` is accepted and stored but never applied
    in ``forward`` — the loss is returned per (batch, seq) position and
    reduced by the caller (the solver calls ``loss.sum()``).  Confirm
    this is intentional before relying on the ``reduction`` argument.
    """

    def __init__(self,
                 weight= None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight, size_average, reduce, reduction)
        self.alpha = alpha  # focal-loss alpha (positive/negative balance)
        self.gamma = gamma  # focal-loss focusing exponent
        self.reduction = reduction  # redundant: super().__init__ already sets it

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, valid_lens) -> torch.Tensor:
        # Mask shaped like targets: 1 inside each sample's valid length, 0 beyond.
        weights = torch.ones_like(targets)
        weights = sequence_mask(weights, valid_lens)
        # Per-element focal loss, masked, then averaged over the class axis only.
        unweighted_loss = sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, reduction='none')
        weighted_loss = (unweighted_loss * weights).mean(dim=-1)
        return weighted_loss
45
24 46
def register_sigmoid_focal_loss():
    # Register both focal-loss variants so build_loss can resolve them by name.
    LOSS_REGISTRY.register()(SigmoidFocalLoss)
    LOSS_REGISTRY.register()(MaskedSigmoidFocalLoss)
27 50
28 51
29 def register_torch_loss(): 52 def register_torch_loss():
......
...@@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY ...@@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY
3 3
4 from .mlp import MLPModel 4 from .mlp import MLPModel
5 from .vit import VisionTransformer 5 from .vit import VisionTransformer
6 from .seq_labeling import SLTransformer
6 7
7 8
8 def build_model(cfg): 9 def build_model(cfg):
......
1 import math
2 from functools import partial
3 from collections import OrderedDict
4
5 import torch
6 import torch.nn as nn
7 from utils.registery import MODEL_REGISTRY
8 from utils import sequence_mask
9
10
def masked_softmax(X, valid_lens):
    """Softmax over the last axis with positions past `valid_lens` masked out.

    X is a 3D score tensor (e.g. [batch*heads, queries, keys] reshaped from
    attention); `valid_lens` is 1D (per batch row) or 2D (per query row), or
    None for a plain softmax.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch row -> repeat it for every query position.
        valid_lens = torch.repeat_interleave(valid_lens, shape[2])
    else:
        valid_lens = valid_lens.reshape(-1)

    # Fill masked slots with a huge negative value so exp() makes them 0.
    flat = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(flat.reshape(shape), dim=-1)
28
29
class PositionalEncoding(nn.Module):
    """Fixed sin/cos positional encoding (Vaswani et al., 2017).

    Adds the first ``seq_len`` rows of the precomputed table to the input
    and applies dropout.

    Fixes over the original:
      * the table is registered as a *non-persistent* buffer, so it follows
        ``module.to(device)`` / ``.cuda()`` automatically while staying out
        of the state_dict (old checkpoints keep loading);
      * odd ``embed_dim`` no longer crashes on the cosine assignment
        (the config uses embed_dim=9).
    """

    def __init__(self, embed_dim, drop_ratio, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(drop_ratio)
        # Precompute the encoding for every position up to max_len.
        P = torch.zeros((1, max_len, embed_dim))
        position = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        div_term = torch.pow(10000, torch.arange(
            0, embed_dim, 2, dtype=torch.float32) / embed_dim)
        angles = position / div_term
        P[:, :, 0::2] = torch.sin(angles)
        # Slice so odd embed_dim works: cos gets one column fewer than sin.
        P[:, :, 1::2] = torch.cos(angles[:, :embed_dim // 2])
        # persistent=False: moves with the module, absent from state_dict.
        self.register_buffer('P', P, persistent=False)

    def forward(self, X):
        # .to(X.device) kept as a safety net if X lives on another device.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
47
48
49 def _init_vit_weights(m):
50 """
51 ViT weight initialization
52 :param m: module
53 """
54 if isinstance(m, nn.Linear):
55 nn.init.trunc_normal_(m.weight, std=.01)
56 if m.bias is not None:
57 nn.init.zeros_(m.bias)
58 elif isinstance(m, nn.Conv2d):
59 nn.init.kaiming_normal_(m.weight, mode="fan_out")
60 if m.bias is not None:
61 nn.init.zeros_(m.bias)
62 elif isinstance(m, nn.LayerNorm):
63 nn.init.zeros_(m.bias)
64 nn.init.ones_(m.weight)
65
66
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample, for the main path of residual
    blocks.  Same mechanism as EfficientNet's "DropConnect", renamed to
    'drop path' to avoid confusion with the original DropConnect paper
    (see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).
    """
    # Identity at inference time or when the rate is zero.
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    # Scale survivors by 1/keep_prob so the expectation is unchanged.
    return x.div(keep_prob) * mask
84
85
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    Thin module wrapper around :func:`drop_path`.
    """
    def __init__(self, drop_prob=None):
        # drop_prob: probability of zeroing an entire sample's residual branch.
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Active only while the module is in training mode.
        return drop_path(x, self.drop_prob, self.training)
96
97
class Attention(nn.Module):
    """Multi-head self-attention with a length-masked softmax over the keys."""

    def __init__(self,
                 dim,                    # token embedding dimension
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Default scale is 1/sqrt(head_dim) unless explicitly overridden.
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x, valid_lens):
        batch, seq_len, channels = x.shape

        # Single projection to q/k/v, then split into heads:
        # [batch, seq, 3*dim] -> [3, batch, heads, seq, head_dim]
        qkv = (self.qkv(x)
               .reshape(batch, seq_len, 3, self.num_heads, channels // self.num_heads)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]  # indexed (not unpacked) for torchscript

        # Scaled dot-product scores: [batch, heads, seq, seq]
        scores = (q @ k.transpose(-2, -1)) * self.scale
        # Softmax over keys with positions past valid_lens masked out.
        attn = self.attn_drop(masked_softmax(scores, valid_lens))

        # Weighted value sum, heads merged back: [batch, seq, dim]
        out = (attn @ v).transpose(1, 2).reshape(batch, seq_len, channels)
        return self.proj_drop(self.proj(out))
140
141
class Mlp(nn.Module):
    """
    Two-layer position-wise feed-forward network, as used in Vision
    Transformer, MLP-Mixer and related architectures.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Hidden/output widths default to the input width.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # fc1 -> activation -> dropout -> fc2 -> dropout
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))
162
163
class Block(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each residual."""

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # Stochastic depth on the residual branches; identity when ratio == 0.
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop_ratio)

    def forward(self, x, valid_lens):
        # Residual 1: length-masked self-attention on the normalized input.
        x = x + self.drop_path(self.attn(self.norm1(x), valid_lens))
        # Residual 2: position-wise MLP on the normalized input.
        return x + self.drop_path(self.mlp(self.norm2(x)))
192
193
@MODEL_REGISTRY.register()
class SLTransformer(nn.Module):
    """Transformer encoder for token-level (sequence-labeling) classification.

    Input is a pre-featurized box sequence [batch, seq, embed_dim]; output
    is per-position class logits [batch, seq, num_classes].  Attention is
    masked with per-sample valid lengths so padding never attends.
    """

    def __init__(self,
                 seq_lens=200,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 norm_layer=None,
                 act_layer=None,
                 ):
        """
        Args:
            seq_lens (int): maximum sequence length (size of the learned positional embedding)
            num_classes (int): number of classes for the per-position head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            norm_layer (nn.Module): normalization layer
            act_layer (nn.Module): activation layer
        """
        super(SLTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        # Learned positional embedding (a fixed sin/cos PositionalEncoding
        # was tried earlier; kept available above).
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_lens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        # Stochastic depth decay rule: linearly increasing rate per block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        # FIX: nn.ModuleList instead of nn.Sequential — the blocks take
        # (x, valid_lens), so Sequential.__call__ could never be used and
        # the model iterated manually anyway.  state_dict keys ("blocks.N.*")
        # are unchanged, so existing checkpoints still load.
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Per-position classifier head.
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(_init_vit_weights)

    def forward(self, x, valid_lens):
        """
        Args:
            x: [batch, seq_len, embed_dim], seq_len <= seq_lens
            valid_lens: [batch], number of real (non-padding) positions
        Returns:
            [batch, seq_len, num_classes] logits
        """
        # FIX/generalization: slice the positional embedding so sequences
        # shorter than seq_lens also work (previously broadcasting required
        # seq_len == seq_lens exactly).
        x = self.pos_drop(x + self.pos_embed[:, :x.shape[1], :])

        # Encoder stack; each block receives the valid lengths for masking.
        for block in self.blocks:
            x = block(x, valid_lens)
        x = self.norm(x)

        # [batch, seq_len, num_classes]
        return self.head(x)
...\ No newline at end of file ...\ No newline at end of file
...@@ -3,6 +3,7 @@ import copy ...@@ -3,6 +3,7 @@ import copy
3 from utils.registery import SOLVER_REGISTRY 3 from utils.registery import SOLVER_REGISTRY
4 from .mlp_solver import MLPSolver 4 from .mlp_solver import MLPSolver
5 from .vit_solver import VITSolver 5 from .vit_solver import VITSolver
6 from .sl_solver import SLSolver
6 7
7 8
8 def build_solver(cfg): 9 def build_solver(cfg):
......
1 import copy
2 import os
3
4 import torch
5
6 from data import build_dataloader
7 from loss import build_loss
8 from model import build_model
9 from optimizer import build_lr_scheduler, build_optimizer
10 from utils import SOLVER_REGISTRY, get_logger_and_log_dir
11 from utils import sequence_mask
12 from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
13
14
@SOLVER_REGISTRY.register()
class SLSolver(object):
    """Training / evaluation driver for the sequence-labeling transformer.

    Builds dataloaders, model, loss, optimizer and logger from the config.
    ``run()`` trains for ``solver.args.epoch`` epochs, checkpointing every
    epoch; ``evaluate()`` loads ``solver.args.model_path`` and prints
    sklearn metrics over the validation set.
    """

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        self.base_on = self.hyper_params['base_on']        # optional checkpoint to resume training from
        self.model_path = self.hyper_params['model_path']  # checkpoint used by evaluate()
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError as e:
            # FIX: `raise '<str>'` raises TypeError in Python 3 (exceptions
            # must derive from BaseException) — raise a real exception.
            raise KeyError('should contain epoch in {solver.args}') from e

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5):
        """Count correctly labeled positions, with padding masked out.

        Both prediction and target are rebuilt into a single class id per
        position: argmax + 1 when the position carries a field, 0 for
        "no field" (prediction below threshold / all-zero target row).
        """
        # [batch, seq, classes] -> per-class probabilities
        y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
        # Class ids shifted by +1 so 0 can mean "no field".
        y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
        # 1 where the best probability clears the threshold, else 0.
        y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        # Ground truth rebuilt the same way (all-zero one-hot rows -> 0).
        y_true_idx = torch.argmax(y_true, dim=-1) + 1
        y_true_is_other = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        # Padding positions become -1 so they can never match a prediction.
        masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)

        return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()

    def train_loop(self):
        """One epoch over the training set; logs loss every 100 iterations."""
        self.model.train()

        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            # Loss is returned per (batch, seq) position; reduce by summing.
            loss = self.loss_fn(pred, y, valid_lens)
            train_loss += loss.sum()

            if batch % 100 == 0:
                loss_value, current = loss.sum().item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.sum().backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # Accuracy is per valid position, not per sample.
        correct /= seq_lens_sum
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        """One pass over the validation set; logs accuracy and mean loss."""
        self.model.eval()

        seq_lens_sum = torch.zeros(1).to(self.device)
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y, valid_lens in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            val_loss += loss.sum()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        correct /= seq_lens_sum
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        """Save the model weights for the given epoch into the log dir."""
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        """Full training loop: optional resume, then epoch x (train/val/save)."""
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            self.model.load_state_dict(torch.load(self.base_on))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            lr_scheduler.step()

        self.logger.info('==> End Training')

    @torch.no_grad()
    def evaluate(self):
        """Load `model_path` and print accuracy / confusion matrix / report.

        FIX: the previous version was stale (copied from a non-sequence
        solver): it unpacked 2-tuples from a loader that yields
        (X, y, valid_lens), called ``self.model(X)`` without the required
        ``valid_lens``, and reduced over dim=1 (the sequence axis) instead
        of the class axis, with no padding mask.
        """
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        label_true_list = []
        label_pred_list = []
        for X, y, valid_lens in self.val_loader:
            X, y_true = X.to(self.device), y.to(self.device)

            pred = self.model(X, valid_lens)
            y_pred = torch.nn.Sigmoid()(pred)

            # Rebuild per-position class ids, same scheme as accuracy().
            y_pred_idx = torch.argmax(y_pred, dim=-1) + 1
            y_pred_is_other = (torch.amax(y_pred, dim=-1) > 0.5).int()
            y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

            y_true_idx = torch.argmax(y_true, dim=-1) + 1
            y_true_is_other = torch.sum(y_true, dim=-1).int()
            y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

            # Keep only positions inside each sample's valid length.
            keep = sequence_mask(torch.ones_like(y_true_rebuild), valid_lens).bool()
            label_true_list.extend(y_true_rebuild[keep].cpu().numpy().tolist())
            label_pred_list.extend(y_pred_rebuild[keep].cpu().numpy().tolist())

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
1 import torch
1 from .registery import * 2 from .registery import *
2 from .logger import get_logger_and_log_dir 3 from .logger import get_logger_and_log_dir
3 4
4 __all__ = [ 5 __all__ = [
5 'Registry', 6 'Registry',
6 'get_logger_and_log_dir', 7 'get_logger_and_log_dir',
8 'sequence_mask',
7 ] 9 ]
8 10
def sequence_mask(X, valid_len, value=0):
    """Mask entries past each row's valid length, writing `value` in place.

    X: [batch, maxlen, ...]; valid_len: [batch].  Positions j >= valid_len[i]
    in row i are overwritten with `value`.  X is mutated and also returned.
    (Defined in :numref:`sec_seq2seq_decoder`.)
    """
    maxlen = X.size(1)
    positions = torch.arange(maxlen, dtype=torch.float32, device=X.device)
    keep = positions[None, :] < valid_len[:, None]
    X[~keep] = value
    return X
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!