add Seq Labeling solver
Showing 12 changed files with 887 additions and 1 deletion
config/sl.yaml
0 → 100644
1 | seed: 3407 | ||
2 | |||
3 | dataset: | ||
4 | name: 'SLData' | ||
5 | args: | ||
6 | data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2' | ||
7 | train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv' | ||
8 | val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv' | ||
9 | |||
10 | dataloader: | ||
11 | batch_size: 8 | ||
12 | num_workers: 4 | ||
13 | pin_memory: true | ||
14 | shuffle: true | ||
15 | |||
16 | model: | ||
17 | name: 'SLTransformer' | ||
18 | args: | ||
19 | seq_lens: 200 | ||
20 | num_classes: 10 | ||
21 | embed_dim: 9 | ||
22 | depth: 6 | ||
23 | num_heads: 1 | ||
24 | mlp_ratio: 4.0 | ||
25 | qkv_bias: true | ||
26 | qk_scale: null | ||
27 | drop_ratio: 0. | ||
28 | attn_drop_ratio: 0. | ||
29 | drop_path_ratio: 0. | ||
30 | norm_layer: null | ||
31 | act_layer: null | ||
32 | |||
33 | solver: | ||
34 | name: 'SLSolver' | ||
35 | args: | ||
36 | epoch: 100 | ||
37 | base_on: null | ||
38 | model_path: null | ||
39 | |||
40 | optimizer: | ||
41 | name: 'Adam' | ||
42 | args: | ||
43 | lr: !!float 1e-3 | ||
44 | # weight_decay: !!float 5e-5 | ||
45 | |||
46 | lr_scheduler: | ||
47 | name: 'CosineLR' | ||
48 | args: | ||
49 | epochs: 100 | ||
50 | lrf: 0.1 | ||
51 | |||
52 | loss: | ||
53 | name: 'MaskedSigmoidFocalLoss' | ||
54 | # name: 'SigmoidFocalLoss' | ||
55 | # name: 'CrossEntropyLoss' | ||
56 | args: | ||
57 | reduction: "mean" | ||
58 | alpha: 0.95 | ||
59 | |||
60 | logger: | ||
61 | log_root: '/Users/zhouweiqi/Downloads/test/logs' | ||
62 | suffix: 'sl-6-1' | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -60,6 +60,7 @@ solver: | ... | @@ -60,6 +60,7 @@ solver: |
60 | # name: 'CrossEntropyLoss' | 60 | # name: 'CrossEntropyLoss' |
61 | args: | 61 | args: |
62 | reduction: "mean" | 62 | reduction: "mean" |
63 | alpha: 0.95 | ||
63 | 64 | ||
64 | logger: | 65 | logger: |
65 | log_root: '/Users/zhouweiqi/Downloads/test/logs' | 66 | log_root: '/Users/zhouweiqi/Downloads/test/logs' | ... | ... |
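Not shown in this diff is the entry point that consumes these YAML files; a minimal sketch of how a config in this shape is typically loaded (load_config is a hypothetical helper; yaml.safe_load resolves the explicit !!float tag used above):

    import yaml

    def load_config(path: str) -> dict:
        # safe_load resolves standard tags such as '!!float 1e-3'
        with open(path, 'r') as fp:
            return yaml.safe_load(fp)

    cfg = load_config('config/sl.yaml')
    assert cfg['model']['args']['embed_dim'] == 9        # one template flag + 8 normalized corner coords
    assert cfg['solver']['optimizer']['args']['lr'] == 1e-3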
data/SLData.py
0 → 100644
1 | import os | ||
2 | import json | ||
3 | import torch | ||
4 | from torch.utils.data import DataLoader, Dataset | ||
5 | import pandas as pd | ||
6 | from utils.registery import DATASET_REGISTRY | ||
7 | |||
8 | |||
9 | @DATASET_REGISTRY.register() | ||
10 | class SLData(Dataset): | ||
11 | |||
12 | def __init__(self, | ||
13 | data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset', | ||
14 | anno_file: str = 'train.csv', | ||
15 | phase: str = 'train'): | ||
16 | self.data_root = data_root | ||
17 | self.df = pd.read_csv(anno_file) | ||
18 | self.phase = phase | ||
19 | |||
20 | |||
21 | def __len__(self): | ||
22 | return len(self.df) | ||
23 | |||
24 | def __getitem__(self, idx): | ||
25 | series = self.df.iloc[idx] | ||
26 | name = series['name'] | ||
27 | |||
28 | with open(os.path.join(self.data_root, self.phase, name), 'r') as fp: | ||
29 | input_list, label_list, valid_lens = json.load(fp) | ||
30 | |||
31 | input_tensor = torch.tensor(input_list) | ||
32 | label_tensor = torch.tensor(label_list).float() | ||
33 | |||
34 | return input_tensor, label_tensor, valid_lens | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
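For reference, each sample file in dataset2 stores [input_list, label_list, valid_lens], so a batch from this dataset looks like the sketch below (shapes follow config/sl.yaml; a sketch, not part of the diff):

    from torch.utils.data import DataLoader

    dataset = SLData(data_root='/Users/zhouweiqi/Downloads/gcfp/data/dataset2',
                     anno_file='/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv',
                     phase='train')
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    X, y, valid_lens = next(iter(loader))
    # X:          [8, 200, 9]   template flag + 8 normalized box coordinates per token
    # y:          [8, 200, 10]  one-hot field labels; all-zero rows mean "other" or padding
    # valid_lens: [8]           real sequence length of each sample (default collate -> tensor)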
... | @@ -3,6 +3,7 @@ from torch.utils.data import DataLoader | ... | @@ -3,6 +3,7 @@ from torch.utils.data import DataLoader |
3 | from utils.registery import DATASET_REGISTRY | 3 | from utils.registery import DATASET_REGISTRY |
4 | 4 | ||
5 | from .CoordinatesData import CoordinatesData | 5 | from .CoordinatesData import CoordinatesData |
6 | from .SLData import SLData | ||
6 | 7 | ||
7 | 8 | ||
8 | def build_dataset(cfg): | 9 | def build_dataset(cfg): | ... | ... |
... | @@ -93,7 +93,8 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -93,7 +93,8 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
93 | label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name)) | 93 | label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name)) |
94 | label_res = load_json(label_json_path) | 94 | label_res = load_json(label_json_path) |
95 | 95 | ||
96 | # invoice date / invoice code / machine-printed number / vehicle type / phone | 96 | # invoice date / invoice code / machine-printed number / vehicle type / phone |
97 | # engine number / frame number (VIN) / account number / deposit bank / amount in figures | ||
97 | test_group_id = [1, 2, 5, 9, 20] | 98 | test_group_id = [1, 2, 5, 9, 20] |
98 | group_list = [] | 99 | group_list = [] |
99 | for group_id in test_group_id: | 100 | for group_id in test_group_id: | ... | ... |
data/create_dataset2.py
0 → 100644
1 | import copy | ||
2 | import json | ||
3 | import os | ||
4 | import random | ||
5 | import uuid | ||
6 | |||
7 | import cv2 | ||
8 | import pandas as pd | ||
9 | from tools import get_file_paths, load_json | ||
10 | |||
11 | |||
12 | def clean_go_res(go_res_dir): | ||
13 | max_seq_count = None | ||
14 | seq_sum = 0 | ||
15 | file_count = 0 | ||
16 | |||
17 | go_res_json_paths = get_file_paths(go_res_dir, ['.json', ]) | ||
18 | for go_res_json_path in go_res_json_paths: | ||
19 | print('Info: start {0}'.format(go_res_json_path)) | ||
20 | |||
21 | remove_key_set = set() | ||
22 | go_res = load_json(go_res_json_path) | ||
23 | for key, (_, text) in go_res.items(): | ||
24 | if text.strip() == '': | ||
25 | remove_key_set.add(key) | ||
26 | print(text) | ||
27 | |||
28 | if len(remove_key_set) > 0: | ||
29 | for del_key in remove_key_set: | ||
30 | del go_res[del_key] | ||
31 | |||
32 | go_res_list = sorted(list(go_res.values()), key=lambda x: (x[0][1], x[0][0]), reverse=False) | ||
33 | |||
34 | with open(go_res_json_path, 'w') as fp: | ||
35 | json.dump(go_res_list, fp) | ||
36 | print('Rewrite {0}'.format(go_res_json_path)) | ||
37 | |||
38 | seq_sum += len(go_res_list) | ||
39 | file_count += 1 | ||
40 | if max_seq_count is None or len(go_res_list) > max_seq_count: | ||
41 | max_seq_count = len(go_res_list) | ||
42 | max_seq_file_name = go_res_json_path | ||
43 | |||
44 | seq_lens_mean = seq_sum // file_count | ||
45 | return max_seq_count, seq_lens_mean, max_seq_file_name | ||
46 | |||
47 | def text_statistics(go_res_dir): | ||
48 | """ | ||
49 | Args: | ||
50 | go_res_dir: str directory holding the generic-OCR JSON files | ||
51 | Returns: list the most frequent texts and their counts | ||
52 | """ | ||
53 | json_count = 0 | ||
54 | text_dict = {} | ||
55 | go_res_json_paths = get_file_paths(go_res_dir, ['.json', ]) | ||
56 | for go_res_json_path in go_res_json_paths: | ||
57 | print('Info: start {0}'.format(go_res_json_path)) | ||
58 | json_count += 1 | ||
59 | go_res = load_json(go_res_json_path) | ||
60 | for _, text in go_res.values(): | ||
61 | if text in text_dict: | ||
62 | text_dict[text] += 1 | ||
63 | else: | ||
64 | text_dict[text] = 1 | ||
65 | top_text_list = [] | ||
66 | # sort by count, descending | ||
67 | for text, count in sorted(text_dict.items(), key=lambda x: x[1], reverse=True): | ||
68 | if text == '': | ||
69 | continue | ||
70 | # discard: stop once a text appears in no more than a third of the files | ||
71 | if count <= json_count // 3: | ||
72 | break | ||
73 | top_text_list.append((text, count)) | ||
74 | return top_text_list | ||
75 | |||
76 | def build_anno_file(dataset_dir, anno_file_path): | ||
77 | img_list = os.listdir(dataset_dir) | ||
78 | random.shuffle(img_list) | ||
79 | df = pd.DataFrame(columns=['name']) | ||
80 | df['name'] = img_list | ||
81 | df.to_csv(anno_file_path) | ||
82 | |||
83 | def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir): | ||
84 | """ | ||
85 | Args: | ||
86 | img_dir: str image directory | ||
87 | go_res_dir: str directory holding the generic-OCR JSON results | ||
88 | label_dir: str directory holding the annotation JSON files | ||
89 | top_text_list: list the most frequent texts and their counts | ||
90 | skip_list: list image files to skip | ||
91 | save_dir: str output directory for the dataset | ||
92 | """ | ||
93 | if os.path.exists(save_dir): | ||
94 | return | ||
95 | else: | ||
96 | os.makedirs(save_dir, exist_ok=True) | ||
97 | |||
98 | # invoice date / invoice code / machine-printed number / vehicle type / phone | ||
99 | # engine number / frame number (VIN) / account number / deposit bank / amount in figures | ||
100 | group_cn_list = ['开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] | ||
101 | test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28] | ||
102 | |||
103 | for img_name in sorted(os.listdir(img_dir)): | ||
104 | if img_name in skip_list: | ||
105 | print('Info: skip {0}'.format(img_name)) | ||
106 | continue | ||
107 | |||
108 | print('Info: start {0}'.format(img_name)) | ||
109 | image_path = os.path.join(img_dir, img_name) | ||
110 | img = cv2.imread(image_path) | ||
111 | h, w, _ = img.shape | ||
112 | base_image_name, _ = os.path.splitext(img_name) | ||
113 | go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name)) | ||
114 | go_res_list = load_json(go_res_json_path) | ||
115 | |||
116 | valid_lens = len(go_res_list) | ||
117 | |||
118 | top_text_idx_set = set() | ||
119 | for top_text, _ in top_text_list: | ||
120 | for go_idx, (_, text) in enumerate(go_res_list): | ||
121 | if text == top_text: | ||
122 | top_text_idx_set.add(go_idx) | ||
123 | break | ||
124 | |||
125 | label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name)) | ||
126 | label_res = load_json(label_json_path) | ||
127 | |||
128 | group_list = [] | ||
129 | for group_id in test_group_id: | ||
130 | for item in label_res.get("shapes", []): | ||
131 | if item.get("group_id") == group_id: | ||
132 | x_list = [] | ||
133 | y_list = [] | ||
134 | for point in item['points']: | ||
135 | x_list.append(point[0]) | ||
136 | y_list.append(point[1]) | ||
137 | group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2]) | ||
138 | break | ||
139 | else: | ||
140 | group_list.append(None) | ||
141 | |||
142 | go_center_list = [] | ||
143 | for (x0, y0, x1, y1, x2, y2, x3, y3), _ in go_res_list: | ||
144 | xmin = min(x0, x1, x2, x3) | ||
145 | ymin = min(y0, y1, y2, y3) | ||
146 | xmax = max(x0, x1, x2, x3) | ||
147 | ymax = max(y0, y1, y2, y3) | ||
148 | xcenter = xmin + (xmax - xmin)/2 | ||
149 | ycenter = ymin + (ymax - ymin)/2 | ||
150 | go_center_list.append((xcenter, ycenter)) | ||
151 | |||
152 | label_idx_dict = dict() | ||
153 | for label_idx, label_center_list in enumerate(group_list): | ||
154 | if isinstance(label_center_list, list): | ||
155 | min_go_key = None | ||
156 | min_length = None | ||
157 | for go_idx, (go_x_center, go_y_center) in enumerate(go_center_list): | ||
158 | if go_idx in top_text_idx_set or go_idx in label_idx_dict: | ||
159 | continue | ||
160 | length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1]) | ||
161 | if min_go_key is None or length < min_length: | ||
162 | min_go_key = go_idx | ||
163 | min_length = length | ||
164 | if min_go_key is not None: | ||
165 | label_idx_dict[min_go_key] = label_idx | ||
166 | |||
167 | X = list() | ||
168 | y_true = list() | ||
169 | for i in range(200): | ||
170 | if i >= valid_lens: | ||
171 | X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.]) | ||
172 | y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | ||
173 | elif i in top_text_idx_set: | ||
174 | (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i] | ||
175 | X.append([1., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h]) | ||
176 | y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | ||
177 | elif i in label_idx_dict: | ||
178 | (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i] | ||
179 | X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h]) | ||
180 | base_label_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | ||
181 | base_label_list[label_idx_dict[i]] = 1 | ||
182 | y_true.append(base_label_list) | ||
183 | else: | ||
184 | (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i] | ||
185 | X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h]) | ||
186 | y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | ||
187 | |||
188 | all_data = [X, y_true, valid_lens] | ||
189 | |||
190 | with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name))), 'w') as fp: | ||
191 | json.dump(all_data, fp) | ||
192 | |||
193 | # print('top text find:') | ||
194 | # for i in top_text_idx_set: | ||
195 | # _, text = go_res_list[i] | ||
196 | # print(text) | ||
197 | |||
198 | # print('-------------') | ||
199 | # print('label value find:') | ||
200 | # for k, v in label_idx_dict.items(): | ||
201 | # _, text = go_res_list[k] | ||
202 | # print('{0}: {1}'.format(group_cn_list[v], text)) | ||
203 | |||
204 | # break | ||
205 | |||
206 | |||
207 | if __name__ == '__main__': | ||
208 | base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' | ||
209 | go_dir = os.path.join(base_dir, 'go_res') | ||
210 | dataset_save_dir = os.path.join(base_dir, 'dataset2') | ||
211 | label_dir = os.path.join(base_dir, 'labeled') | ||
212 | |||
213 | train_go_path = os.path.join(go_dir, 'train') | ||
214 | train_image_path = os.path.join(label_dir, 'train', 'image') | ||
215 | train_label_path = os.path.join(label_dir, 'train', 'label') | ||
216 | train_dataset_dir = os.path.join(dataset_save_dir, 'train') | ||
217 | train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv') | ||
218 | |||
219 | valid_go_path = os.path.join(go_dir, 'valid') | ||
220 | valid_image_path = os.path.join(label_dir, 'valid', 'image') | ||
221 | valid_label_path = os.path.join(label_dir, 'valid', 'label') | ||
222 | valid_dataset_dir = os.path.join(dataset_save_dir, 'valid') | ||
223 | valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv') | ||
224 | |||
225 | # max_seq_lens, seq_lens_mean, max_seq_file_name = clean_go_res(go_dir) | ||
226 | # print(max_seq_lens) # 152 | ||
227 | # print(max_seq_file_name) # CH-B101805176_page_2_img_0.json | ||
228 | # print(seq_lens_mean) # 92 | ||
229 | |||
230 | # top_text_list = text_statistics(go_dir) | ||
231 | # for t in top_text_list: | ||
232 | # print(t) | ||
233 | |||
234 | filter_from_top_text_list = [ | ||
235 | ('机器编号', 496), | ||
236 | ('购买方名称', 496), | ||
237 | ('合格证号', 495), | ||
238 | ('进口证明书号', 495), | ||
239 | ('机打代码', 494), | ||
240 | ('车辆类型', 492), | ||
241 | ('完税凭证号码', 492), | ||
242 | ('机打号码', 491), | ||
243 | ('发动机号码', 491), | ||
244 | ('主管税务', 491), | ||
245 | ('价税合计', 489), | ||
246 | ('机关及代码', 489), | ||
247 | ('销货单位名称', 486), | ||
248 | ('厂牌型号', 485), | ||
249 | ('产地', 485), | ||
250 | ('商检单号', 483), | ||
251 | ('电话', 476), | ||
252 | ('开户银行', 472), | ||
253 | ('车辆识别代号/车架号码', 463), | ||
254 | ('身份证号码', 454), | ||
255 | ('吨位', 452), | ||
256 | ('备注:一车一票', 439), | ||
257 | ('地', 432), | ||
258 | ('账号', 431), | ||
259 | ('统一社会信用代码/', 424), | ||
260 | ('限乘人数', 404), | ||
261 | ('税额', 465), | ||
262 | ('址', 392) | ||
263 | ] | ||
264 | |||
265 | skip_list_train = [ | ||
266 | 'CH-B101910792-page-12.jpg', | ||
267 | 'CH-B101655312-page-13.jpg', | ||
268 | 'CH-B102278656.jpg', | ||
269 | 'CH-B101846620_page_1_img_0.jpg', | ||
270 | 'CH-B103062528-0.jpg', | ||
271 | 'CH-B102613120-3.jpg', | ||
272 | 'CH-B102997980-3.jpg', | ||
273 | 'CH-B102680060-3.jpg', | ||
274 | # 'CH-B102995500-2.jpg', # no value fields labelled | ||
275 | ] | ||
276 | |||
277 | skip_list_valid = [ | ||
278 | 'CH-B102897920-2.jpg', | ||
279 | 'CH-B102551284-0.jpg', | ||
280 | 'CH-B102879376-2.jpg', | ||
281 | 'CH-B101509488-page-16.jpg', | ||
282 | 'CH-B102708352-2.jpg', | ||
283 | ] | ||
284 | |||
285 | build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) | ||
286 | build_anno_file(train_dataset_dir, train_anno_file_path) | ||
287 | |||
288 | build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir) | ||
289 | build_anno_file(valid_dataset_dir, valid_anno_file_path) | ||
290 | |||
291 |
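The heart of build_dataset above is a greedy matching step: each labelled field center claims the nearest still-unclaimed OCR box under L1 distance, with the template texts excluded. A standalone restatement of that loop (match_labels is a hypothetical name; the logic mirrors lines 152-165 above):

    def match_labels(go_centers, label_centers, taken):
        """go_centers: (x, y) centers of the OCR boxes; label_centers: (x, y) or None
        per field; taken: OCR indices excluded from matching (the template texts)."""
        assigned = {}  # go_idx -> label_idx, each OCR box claimed at most once
        for label_idx, center in enumerate(label_centers):
            if center is None:
                continue
            best_idx, best_dist = None, None
            for go_idx, (gx, gy) in enumerate(go_centers):
                if go_idx in taken or go_idx in assigned:
                    continue
                dist = abs(gx - center[0]) + abs(gy - center[1])  # L1 distance
                if best_idx is None or dist < best_dist:
                    best_idx, best_dist = go_idx, dist
            if best_idx is not None:
                assigned[best_idx] = label_idx
        return assigned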
... | @@ -2,6 +2,7 @@ import copy | ... | @@ -2,6 +2,7 @@ import copy |
2 | import torch | 2 | import torch |
3 | import inspect | 3 | import inspect |
4 | from utils.registery import LOSS_REGISTRY | 4 | from utils.registery import LOSS_REGISTRY |
5 | from utils import sequence_mask | ||
5 | from torchvision.ops import sigmoid_focal_loss | 6 | from torchvision.ops import sigmoid_focal_loss |
6 | 7 | ||
7 | class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): | 8 | class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): |
... | @@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): | ... | @@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): |
21 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: | 22 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: |
22 | return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction) | 23 | return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction) |
23 | 24 | ||
25 | class MaskedSigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): | ||
26 | |||
27 | def __init__(self, | ||
28 | weight= None, | ||
29 | size_average=None, | ||
30 | reduce=None, | ||
31 | reduction: str = 'mean', | ||
32 | alpha: float = 0.25, | ||
33 | gamma: float = 2): | ||
34 | super().__init__(weight, size_average, reduce, reduction) | ||
35 | self.alpha = alpha | ||
36 | self.gamma = gamma | ||
37 | self.reduction = reduction | ||
38 | |||
39 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor, valid_lens) -> torch.Tensor: | ||
40 | weights = torch.ones_like(targets) | ||
41 | weights = sequence_mask(weights, valid_lens) | ||
42 | unweighted_loss = sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, reduction='none') | ||
43 | weighted_loss = (unweighted_loss * weights).mean(dim=-1) | ||
44 | return weighted_loss | ||
45 | |||
24 | 46 | ||
25 | def register_sigmoid_focal_loss(): | 47 | def register_sigmoid_focal_loss(): |
26 | LOSS_REGISTRY.register()(SigmoidFocalLoss) | 48 | LOSS_REGISTRY.register()(SigmoidFocalLoss) |
49 | LOSS_REGISTRY.register()(MaskedSigmoidFocalLoss) | ||
27 | 50 | ||
28 | 51 | ||
29 | def register_torch_loss(): | 52 | def register_torch_loss(): | ... | ... |
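A usage sketch for the new masked loss (shapes match config/sl.yaml; note that forward ignores the stored reduction argument and returns a per-position loss, which SLSolver then reduces with .sum()):

    import torch

    loss_fn = MaskedSigmoidFocalLoss(reduction='mean', alpha=0.95)
    logits  = torch.randn(8, 200, 10, requires_grad=True)
    targets = torch.zeros(8, 200, 10)            # all "other", just for the example
    valid_lens = torch.randint(50, 200, (8,))

    loss = loss_fn(logits, targets, valid_lens)  # [8, 200]: mean over classes per position
    loss.sum().backward()                        # positions beyond valid_lens contribute 0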
... | @@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY | ... | @@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY |
3 | 3 | ||
4 | from .mlp import MLPModel | 4 | from .mlp import MLPModel |
5 | from .vit import VisionTransformer | 5 | from .vit import VisionTransformer |
6 | from .seq_labeling import SLTransformer | ||
6 | 7 | ||
7 | 8 | ||
8 | def build_model(cfg): | 9 | def build_model(cfg): | ... | ... |
model/seq_labeling.py
0 → 100644
1 | import math | ||
2 | from functools import partial | ||
3 | from collections import OrderedDict | ||
4 | |||
5 | import torch | ||
6 | import torch.nn as nn | ||
7 | from utils.registery import MODEL_REGISTRY | ||
8 | from utils import sequence_mask | ||
9 | |||
10 | |||
11 | def masked_softmax(X, valid_lens): | ||
12 | """Perform softmax operation by masking elements on the last axis. | ||
13 | Defined in :numref:`sec_attention-scoring-functions`""" | ||
14 | # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor | ||
15 | if valid_lens is None: | ||
16 | return nn.functional.softmax(X, dim=-1) | ||
17 | else: | ||
18 | # [batch_size, num_heads, seq_len, seq_len] | ||
19 | shape = X.shape | ||
20 | if valid_lens.dim() == 1: | ||
21 | valid_lens = torch.repeat_interleave(valid_lens, shape[2]) | ||
22 | else: | ||
23 | valid_lens = valid_lens.reshape(-1) | ||
24 | # On the last axis, replace masked elements with a very large negative | ||
25 | # value, whose exponentiation outputs 0 | ||
26 | X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6) | ||
27 | return nn.functional.softmax(X.reshape(shape), dim=-1) | ||
28 | |||
29 | |||
30 | class PositionalEncoding(nn.Module): | ||
31 | """Positional encoding. | ||
32 | Defined in :numref:`sec_self-attention-and-positional-encoding`""" | ||
33 | def __init__(self, embed_dim, drop_ratio, max_len=1000): | ||
34 | super(PositionalEncoding, self).__init__() | ||
35 | self.dropout = nn.Dropout(drop_ratio) | ||
36 | # Create a long enough `P` | ||
37 | self.P = torch.zeros((1, max_len, embed_dim)) | ||
38 | X = torch.arange(max_len, dtype=torch.float32).reshape( | ||
39 | -1, 1) / torch.pow(10000, torch.arange( | ||
40 | 0, embed_dim, 2, dtype=torch.float32) / embed_dim) | ||
41 | self.P[:, :, 0::2] = torch.sin(X) | ||
42 | self.P[:, :, 1::2] = torch.cos(X) | ||
43 | |||
44 | def forward(self, X): | ||
45 | X = X + self.P[:, :X.shape[1], :].to(X.device) | ||
46 | return self.dropout(X) | ||
47 | |||
48 | |||
49 | def _init_vit_weights(m): | ||
50 | """ | ||
51 | ViT weight initialization | ||
52 | :param m: module | ||
53 | """ | ||
54 | if isinstance(m, nn.Linear): | ||
55 | nn.init.trunc_normal_(m.weight, std=.01) | ||
56 | if m.bias is not None: | ||
57 | nn.init.zeros_(m.bias) | ||
58 | elif isinstance(m, nn.Conv2d): | ||
59 | nn.init.kaiming_normal_(m.weight, mode="fan_out") | ||
60 | if m.bias is not None: | ||
61 | nn.init.zeros_(m.bias) | ||
62 | elif isinstance(m, nn.LayerNorm): | ||
63 | nn.init.zeros_(m.bias) | ||
64 | nn.init.ones_(m.weight) | ||
65 | |||
66 | |||
67 | def drop_path(x, drop_prob: float = 0., training: bool = False): | ||
68 | """ | ||
69 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | ||
70 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, | ||
71 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... | ||
72 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for | ||
73 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use | ||
74 | 'survival rate' as the argument. | ||
75 | """ | ||
76 | if drop_prob == 0. or not training: | ||
77 | return x | ||
78 | keep_prob = 1 - drop_prob | ||
79 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets | ||
80 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) | ||
81 | random_tensor.floor_() # binarize | ||
82 | output = x.div(keep_prob) * random_tensor | ||
83 | return output | ||
84 | |||
85 | |||
86 | class DropPath(nn.Module): | ||
87 | """ | ||
88 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | ||
89 | """ | ||
90 | def __init__(self, drop_prob=None): | ||
91 | super(DropPath, self).__init__() | ||
92 | self.drop_prob = drop_prob | ||
93 | |||
94 | def forward(self, x): | ||
95 | return drop_path(x, self.drop_prob, self.training) | ||
96 | |||
97 | |||
98 | class Attention(nn.Module): | ||
99 | def __init__(self, | ||
100 | dim, # input token dimension | ||
101 | num_heads=8, | ||
102 | qkv_bias=False, | ||
103 | qk_scale=None, | ||
104 | attn_drop_ratio=0., | ||
105 | proj_drop_ratio=0.): | ||
106 | super(Attention, self).__init__() | ||
107 | self.num_heads = num_heads | ||
108 | head_dim = dim // num_heads | ||
109 | self.scale = qk_scale or head_dim ** -0.5 | ||
110 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | ||
111 | self.attn_drop = nn.Dropout(attn_drop_ratio) | ||
112 | self.proj = nn.Linear(dim, dim) | ||
113 | self.proj_drop = nn.Dropout(proj_drop_ratio) | ||
114 | |||
115 | def forward(self, x, valid_lens): | ||
116 | # [batch_size, seq_len, total_embed_dim] | ||
117 | B, N, C = x.shape | ||
118 | |||
119 | # qkv(): -> [batch_size, seq_len, 3 * total_embed_dim] | ||
120 | # reshape: -> [batch_size, seq_len, 3, num_heads, embed_dim_per_head] | ||
121 | # permute: -> [3, batch_size, num_heads, seq_len, embed_dim_per_head] | ||
122 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) | ||
123 | # [batch_size, num_heads, seq_len, embed_dim_per_head] | ||
124 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) | ||
125 | |||
126 | # transpose: -> [batch_size, num_heads, embed_dim_per_head, seq_len] | ||
127 | # @: multiply -> [batch_size, num_heads, seq_len, seq_len] | ||
128 | attn = (q @ k.transpose(-2, -1)) * self.scale | ||
129 | # attn = attn.softmax(dim=-1) | ||
130 | attn = masked_softmax(attn, valid_lens) | ||
131 | attn = self.attn_drop(attn) | ||
132 | |||
133 | # @: multiply -> [batch_size, num_heads, seq_len, embed_dim_per_head] | ||
134 | # transpose: -> [batch_size, seq_len, num_heads, embed_dim_per_head] | ||
135 | # reshape: -> [batch_size, seq_len, total_embed_dim] | ||
136 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) | ||
137 | x = self.proj(x) | ||
138 | x = self.proj_drop(x) | ||
139 | return x | ||
140 | |||
141 | |||
142 | class Mlp(nn.Module): | ||
143 | """ | ||
144 | MLP as used in Vision Transformer, MLP-Mixer and related networks | ||
145 | """ | ||
146 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): | ||
147 | super().__init__() | ||
148 | out_features = out_features or in_features | ||
149 | hidden_features = hidden_features or in_features | ||
150 | self.fc1 = nn.Linear(in_features, hidden_features) | ||
151 | self.act = act_layer() | ||
152 | self.fc2 = nn.Linear(hidden_features, out_features) | ||
153 | self.drop = nn.Dropout(drop) | ||
154 | |||
155 | def forward(self, x): | ||
156 | x = self.fc1(x) | ||
157 | x = self.act(x) | ||
158 | x = self.drop(x) | ||
159 | x = self.fc2(x) | ||
160 | x = self.drop(x) | ||
161 | return x | ||
162 | |||
163 | |||
164 | class Block(nn.Module): | ||
165 | def __init__(self, | ||
166 | dim, | ||
167 | num_heads, | ||
168 | mlp_ratio=4., | ||
169 | qkv_bias=False, | ||
170 | qk_scale=None, | ||
171 | drop_ratio=0., | ||
172 | attn_drop_ratio=0., | ||
173 | drop_path_ratio=0., | ||
174 | act_layer=nn.GELU, | ||
175 | norm_layer=nn.LayerNorm): | ||
176 | super(Block, self).__init__() | ||
177 | self.norm1 = norm_layer(dim) | ||
178 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, | ||
179 | attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio) | ||
180 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | ||
181 | self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity() | ||
182 | self.norm2 = norm_layer(dim) | ||
183 | mlp_hidden_dim = int(dim * mlp_ratio) | ||
184 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio) | ||
185 | |||
186 | def forward(self, x, valid_lens): | ||
187 | # [batch_size, seq_len, total_embed_dim] | ||
188 | x = x + self.drop_path(self.attn(self.norm1(x), valid_lens)) | ||
189 | # [batch_size, seq_len, total_embed_dim] | ||
190 | x = x + self.drop_path(self.mlp(self.norm2(x))) | ||
191 | return x | ||
192 | |||
193 | |||
194 | @MODEL_REGISTRY.register() | ||
195 | class SLTransformer(nn.Module): | ||
196 | def __init__(self, | ||
197 | seq_lens=200, | ||
198 | num_classes=1000, | ||
199 | embed_dim=768, | ||
200 | depth=12, | ||
201 | num_heads=12, | ||
202 | mlp_ratio=4.0, | ||
203 | qkv_bias=True, | ||
204 | qk_scale=None, | ||
205 | drop_ratio=0., | ||
206 | attn_drop_ratio=0., | ||
207 | drop_path_ratio=0., | ||
208 | norm_layer=None, | ||
209 | act_layer=None, | ||
210 | ): | ||
211 | """ | ||
212 | Args: | ||
213 | num_classes (int): number of classes for classification head | ||
214 | embed_dim (int): embedding dimension | ||
215 | depth (int): depth of transformer | ||
216 | num_heads (int): number of attention heads | ||
217 | mlp_ratio (float): ratio of mlp hidden dim to embedding dim | ||
218 | qkv_bias (bool): enable bias for qkv if True | ||
219 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set | ||
220 | drop_ratio (float): dropout rate | ||
221 | attn_drop_ratio (float): attention dropout rate | ||
222 | drop_path_ratio (float): stochastic depth rate | ||
223 | norm_layer: (nn.Module): normalization layer | ||
224 | """ | ||
225 | super(SLTransformer, self).__init__() | ||
226 | self.num_classes = num_classes | ||
227 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | ||
228 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) | ||
229 | act_layer = act_layer or nn.GELU | ||
230 | |||
231 | # self.pos_embed = PositionalEncoding(self.embed_dim, drop_ratio, max_len=seq_lens) | ||
232 | self.pos_embed = nn.Parameter(torch.zeros(1, seq_lens, embed_dim)) | ||
233 | self.pos_drop = nn.Dropout(p=drop_ratio) | ||
234 | |||
235 | dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule | ||
236 | self.blocks = nn.Sequential(*[ | ||
237 | Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, | ||
238 | drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i], | ||
239 | norm_layer=norm_layer, act_layer=act_layer) | ||
240 | for i in range(depth) | ||
241 | ]) | ||
242 | self.norm = norm_layer(embed_dim) | ||
243 | |||
244 | # Classifier head(s) | ||
245 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() | ||
246 | |||
247 | # Weight init | ||
248 | nn.init.trunc_normal_(self.pos_embed, std=0.02) | ||
249 | self.apply(_init_vit_weights) | ||
250 | |||
251 | def forward(self, x, valid_lens): | ||
252 | # x: [B, seq_len, embed_dim] | ||
253 | # valid_lens: [B, ] | ||
254 | |||
255 | # TODO: sin/cos positional encoding instead? | ||
256 | # Since positional-encoding values lie between -1 and 1, | ||
257 | # the embeddings would be scaled by the square root of the embedding dimension | ||
258 | # before the positional encoding is added: | ||
259 | # x = self.pos_embed(x * math.sqrt(self.embed_dim)) | ||
260 | |||
261 | # learned positional embedding (an nn.Parameter) | ||
262 | x = self.pos_drop(x + self.pos_embed) | ||
263 | |||
264 | # [batch_size, seq_len, total_embed_dim] | ||
265 | for block in self.blocks: | ||
266 | x = block(x, valid_lens) | ||
267 | # x = self.blocks(x, valid_lens) | ||
268 | x = self.norm(x) | ||
269 | |||
270 | # [batch_size, seq_len, num_classes] | ||
271 | x = self.head(x) | ||
272 | return x | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
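A quick shape check for the new model under the settings in config/sl.yaml (embed_dim=9 forces num_heads=1, since head_dim = dim // num_heads); a sketch, not part of the diff:

    import torch

    model = SLTransformer(seq_lens=200, num_classes=10, embed_dim=9,
                          depth=6, num_heads=1, mlp_ratio=4.0, qkv_bias=True)
    X = torch.rand(2, 200, 9)
    valid_lens = torch.tensor([92, 152])   # e.g. the mean/max sequence lengths noted above

    out = model(X, valid_lens)
    print(out.shape)   # torch.Size([2, 200, 10]) - one logit vector per sequence position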
... | @@ -3,6 +3,7 @@ import copy | ... | @@ -3,6 +3,7 @@ import copy |
3 | from utils.registery import SOLVER_REGISTRY | 3 | from utils.registery import SOLVER_REGISTRY |
4 | from .mlp_solver import MLPSolver | 4 | from .mlp_solver import MLPSolver |
5 | from .vit_solver import VITSolver | 5 | from .vit_solver import VITSolver |
6 | from .sl_solver import SLSolver | ||
6 | 7 | ||
7 | 8 | ||
8 | def build_solver(cfg): | 9 | def build_solver(cfg): | ... | ... |
solver/sl_solver.py
0 → 100644
1 | import copy | ||
2 | import os | ||
3 | |||
4 | import torch | ||
5 | |||
6 | from data import build_dataloader | ||
7 | from loss import build_loss | ||
8 | from model import build_model | ||
9 | from optimizer import build_lr_scheduler, build_optimizer | ||
10 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir | ||
11 | from utils import sequence_mask | ||
12 | from sklearn.metrics import confusion_matrix, accuracy_score, classification_report | ||
13 | |||
14 | |||
15 | @SOLVER_REGISTRY.register() | ||
16 | class SLSolver(object): | ||
17 | |||
18 | def __init__(self, cfg): | ||
19 | self.device = "cuda" if torch.cuda.is_available() else "cpu" | ||
20 | |||
21 | self.cfg = copy.deepcopy(cfg) | ||
22 | |||
23 | self.train_loader, self.val_loader = build_dataloader(cfg) | ||
24 | self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader) | ||
25 | self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset) | ||
26 | |||
27 | # BatchNorm ? | ||
28 | self.model = build_model(cfg).to(self.device) | ||
29 | |||
30 | self.loss_fn = build_loss(cfg) | ||
31 | |||
32 | self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args']) | ||
33 | |||
34 | self.hyper_params = cfg['solver']['args'] | ||
35 | self.base_on = self.hyper_params['base_on'] | ||
36 | self.model_path = self.hyper_params['model_path'] | ||
37 | try: | ||
38 | self.epoch = self.hyper_params['epoch'] | ||
39 | except KeyError: | ||
40 | raise KeyError('cfg should contain epoch in solver.args') | ||
41 | |||
42 | self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger']) | ||
43 | |||
44 | def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5): | ||
45 | # [batch_size, seq_len, num_classes] | ||
46 | y_pred_sigmoid = torch.nn.Sigmoid()(y_pred) | ||
47 | # [batch_size, seq_len] | ||
48 | y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1 | ||
49 | # [batch_size, seq_len] | ||
50 | y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int() | ||
51 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | ||
52 | |||
53 | y_true_idx = torch.argmax(y_true, dim=-1) + 1 | ||
54 | y_true_is_other = torch.sum(y_true, dim=-1).int() | ||
55 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) | ||
56 | |||
57 | masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1) | ||
58 | |||
59 | return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item() | ||
60 | |||
61 | def train_loop(self): | ||
62 | self.model.train() | ||
63 | |||
64 | seq_lens_sum = torch.zeros(1).to(self.device) | ||
65 | train_loss = torch.zeros(1).to(self.device) | ||
66 | correct = torch.zeros(1).to(self.device) | ||
67 | for batch, (X, y, valid_lens) in enumerate(self.train_loader): | ||
68 | X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)  # valid_lens must share the model's device for masked attention | ||
69 | |||
70 | pred = self.model(X, valid_lens) | ||
71 | # [batch_size, seq_len, num_classes] | ||
72 | |||
73 | loss = self.loss_fn(pred, y, valid_lens) | ||
74 | train_loss += loss.sum().detach()  # detach so the running total does not retain the autograd graph | ||
75 | |||
76 | if batch % 100 == 0: | ||
77 | loss_value, current = loss.sum().item(), batch | ||
78 | self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}') | ||
79 | |||
80 | self.optimizer.zero_grad() | ||
81 | loss.sum().backward() | ||
82 | self.optimizer.step() | ||
83 | |||
84 | seq_lens_sum += valid_lens.sum() | ||
85 | correct += self.accuracy(pred, y, valid_lens) | ||
86 | |||
87 | # correct /= self.train_dataset_size | ||
88 | correct /= seq_lens_sum | ||
89 | train_loss /= self.train_loader_size | ||
90 | self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}') | ||
91 | |||
92 | @torch.no_grad() | ||
93 | def val_loop(self, t): | ||
94 | self.model.eval() | ||
95 | |||
96 | seq_lens_sum = torch.zeros(1).to(self.device) | ||
97 | val_loss = torch.zeros(1).to(self.device) | ||
98 | correct = torch.zeros(1).to(self.device) | ||
99 | for X, y, valid_lens in self.val_loader: | ||
100 | X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device) | ||
101 | |||
102 | # pred = torch.nn.Sigmoid()(self.model(X)) | ||
103 | pred = self.model(X, valid_lens) | ||
104 | # [batch_size, seq_len, num_classes] | ||
105 | |||
106 | loss = self.loss_fn(pred, y, valid_lens) | ||
107 | val_loss += loss.sum() | ||
108 | |||
109 | seq_lens_sum += valid_lens.sum() | ||
110 | correct += self.accuracy(pred, y, valid_lens) | ||
111 | |||
112 | # correct /= self.val_dataset_size | ||
113 | correct /= seq_lens_sum | ||
114 | val_loss /= self.val_loader_size | ||
115 | |||
116 | self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}") | ||
117 | |||
118 | def save_checkpoint(self, epoch_id): | ||
119 | self.model.eval() | ||
120 | torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt')) | ||
121 | |||
122 | def run(self): | ||
123 | if isinstance(self.base_on, str) and os.path.exists(self.base_on): | ||
124 | self.model.load_state_dict(torch.load(self.base_on)) | ||
125 | self.logger.info(f'==> Load Model from {self.base_on}') | ||
126 | |||
127 | self.logger.info('==> Start Training') | ||
128 | print(self.model) | ||
129 | |||
130 | lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args']) | ||
131 | |||
132 | for t in range(self.epoch): | ||
133 | self.logger.info(f'==> epoch {t + 1}') | ||
134 | |||
135 | self.train_loop() | ||
136 | self.val_loop(t + 1) | ||
137 | self.save_checkpoint(t + 1) | ||
138 | |||
139 | lr_scheduler.step() | ||
140 | |||
141 | self.logger.info('==> End Training') | ||
142 | |||
143 | # def run(self): | ||
144 | # from torch.nn import functional | ||
145 | |||
146 | # y = functional.one_hot(torch.randint(0, 10, (8, 100)), 10) | ||
147 | # valid_lens = torch.randint(50, 100, (8, )) | ||
148 | # print(valid_lens) | ||
149 | |||
150 | # pred = functional.one_hot(torch.randint(0, 10, (8, 100)), 10) | ||
151 | |||
152 | # print(self.accuracy(pred, y, valid_lens)) | ||
153 | |||
154 | def evaluate(self): | ||
155 | if isinstance(self.model_path, str) and os.path.exists(self.model_path): | ||
156 | self.model.load_state_dict(torch.load(self.model_path)) | ||
157 | self.logger.info(f'==> Load Model from {self.model_path}') | ||
158 | else: | ||
159 | return | ||
160 | |||
161 | self.model.eval() | ||
162 | |||
163 | label_true_list = [] | ||
164 | label_pred_list = [] | ||
165 | for X, y, valid_lens in self.val_loader: | ||
166 | X, y_true = X.to(self.device), y.to(self.device) | ||
167 | |||
168 | # the model forward requires valid_lens for masked attention | ||
169 | pred = self.model(X, valid_lens.to(self.device)) | ||
170 | |||
171 | y_pred = torch.nn.Sigmoid()(pred) | ||
172 | |||
173 | y_pred_idx = torch.argmax(y_pred, dim=-1) + 1  # class axis is last: [batch_size, seq_len, num_classes] | ||
174 | y_pred_is_other = (torch.amax(y_pred, dim=-1) > 0.5).int() | ||
175 | y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other) | ||
176 | |||
177 | y_true_idx = torch.argmax(y_true, dim=-1) + 1 | ||
178 | y_true_is_other = torch.sum(y_true, dim=-1).int() | ||
179 | y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other) | ||
180 | |||
181 | label_true_list.extend(y_true_rebuild.reshape(-1).cpu().numpy().tolist()) | ||
182 | label_pred_list.extend(y_pred_rebuild.reshape(-1).cpu().numpy().tolist()) | ||
183 | |||
184 | |||
185 | acc = accuracy_score(label_true_list, label_pred_list) | ||
186 | cm = confusion_matrix(label_true_list, label_pred_list) | ||
187 | report = classification_report(label_true_list, label_pred_list) | ||
188 | print(acc) | ||
189 | print(cm) | ||
190 | print(report) |
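The rebuild trick used in accuracy() and evaluate() collapses the 10 one-hot classes plus the implicit "other" class into a single integer label: argmax + 1 yields classes 1..10, and multiplying by the row sum sends all-zero rows to 0. A toy check:

    import torch

    y = torch.tensor([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],    # field class 3
                      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])   # "other": no field assigned
    idx = torch.argmax(y, dim=-1) + 1        # tensor([3, 1])
    is_field = torch.sum(y, dim=-1).int()    # tensor([1, 0])
    print(torch.multiply(idx, is_field))     # tensor([3, 0])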
1 | import torch | ||
1 | from .registery import * | 2 | from .registery import * |
2 | from .logger import get_logger_and_log_dir | 3 | from .logger import get_logger_and_log_dir |
3 | 4 | ||
4 | __all__ = [ | 5 | __all__ = [ |
5 | 'Registry', | 6 | 'Registry', |
6 | 'get_logger_and_log_dir', | 7 | 'get_logger_and_log_dir', |
8 | 'sequence_mask', | ||
7 | ] | 9 | ] |
8 | 10 | ||
11 | def sequence_mask(X, valid_len, value=0): | ||
12 | """Mask irrelevant entries in sequences. | ||
13 | Defined in :numref:`sec_seq2seq_decoder`""" | ||
14 | maxlen = X.size(1) | ||
15 | mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None] | ||
16 | X[~mask] = value | ||
17 | return X | ... | ... |
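sequence_mask overwrites every position at or beyond each sample's valid length with value (0 by default); note that it mutates X in place as well as returning it. A small example:

    import torch
    from utils import sequence_mask

    X = torch.ones(2, 4)
    print(sequence_mask(X, torch.tensor([1, 3])))
    # tensor([[1., 0., 0., 0.],
    #         [1., 1., 1., 0.]])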