add Seq Labeling solver

Showing 12 changed files with 887 additions and 1 deletion
config/sl.yaml
0 → 100644
seed: 3407

dataset:
  name: 'SLData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv'

dataloader:
  batch_size: 8
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'SLTransformer'
  args:
    seq_lens: 200
    num_classes: 10
    embed_dim: 9
    depth: 6
    num_heads: 1
    mlp_ratio: 4.0
    qkv_bias: true
    qk_scale: null
    drop_ratio: 0.
    attn_drop_ratio: 0.
    drop_path_ratio: 0.
    norm_layer: null
    act_layer: null

solver:
  name: 'SLSolver'
  args:
    epoch: 100
    base_on: null
    model_path: null

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-3
      # weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'CosineLR'
    args:
      epochs: 100
      lrf: 0.1

  loss:
    name: 'MaskedSigmoidFocalLoss'
    # name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"
      alpha: 0.95

  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'sl-6-1'
@@ -60,6 +60,7 @@ solver:
     # name: 'CrossEntropyLoss'
     args:
       reduction: "mean"
+      alpha: 0.95

   logger:
     log_root: '/Users/zhouweiqi/Downloads/test/logs'
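For context, these configs are dispatched through the registry pattern used across the repo (utils.registery). The following is a minimal sketch of that flow, with a stand-in Registry and model class rather than the repo's actual ones; it requires PyYAML and assumes config/sl.yaml is on disk:

import yaml

# Stand-in for utils.registery.Registry (assumed interface, not the repo's code).
class Registry:
    def __init__(self):
        self._obj_map = {}

    def register(self):
        def deco(cls):
            self._obj_map[cls.__name__] = cls
            return cls
        return deco

    def get(self, name):
        return self._obj_map[name]

MODEL_REGISTRY = Registry()

@MODEL_REGISTRY.register()
class SLTransformer:  # placeholder; the real model is in model/seq_labeling.py below
    def __init__(self, **kwargs):
        self.kwargs = kwargs

# The 'name' field selects the registered class; 'args' become its kwargs.
with open('config/sl.yaml') as fp:
    cfg = yaml.safe_load(fp)
model = MODEL_REGISTRY.get(cfg['model']['name'])(**cfg['model']['args'])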
data/SLData.py
0 → 100644
import os
import json
import torch
from torch.utils.data import Dataset
import pandas as pd
from utils.registery import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SLData(Dataset):

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        series = self.df.iloc[idx]
        name = series['name']

        # each sample is a JSON file of [input_list, label_list, valid_lens]
        with open(os.path.join(self.data_root, self.phase, name), 'r') as fp:
            input_list, label_list, valid_lens = json.load(fp)

        input_tensor = torch.tensor(input_list)
        label_tensor = torch.tensor(label_list).float()

        return input_tensor, label_tensor, valid_lens
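A quick smoke test of the loader output, assuming dataset2 has already been generated by data/create_dataset2.py below (paths taken from config/sl.yaml):

from torch.utils.data import DataLoader

dataset = SLData(
    data_root='/Users/zhouweiqi/Downloads/gcfp/data/dataset2',
    anno_file='/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv',
    phase='train',
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

X, y, valid_lens = next(iter(loader))
print(X.shape)           # torch.Size([8, 200, 9])   one feature row per OCR box
print(y.shape)           # torch.Size([8, 200, 10])  one-hot over the 10 field classes
print(valid_lens.shape)  # torch.Size([8])           true box count per document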
@@ -3,6 +3,7 @@ from torch.utils.data import DataLoader
 from utils.registery import DATASET_REGISTRY

 from .CoordinatesData import CoordinatesData
+from .SLData import SLData


 def build_dataset(cfg):
@@ -93,7 +93,8 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
     label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
     label_res = load_json(label_json_path)

     # invoice date, invoice code, machine-printed number, vehicle type, phone
+    # engine number, frame number (VIN), account number, bank of deposit, amount in figures
     test_group_id = [1, 2, 5, 9, 20]
     group_list = []
     for group_id in test_group_id:
data/create_dataset2.py
0 → 100644
import json
import os
import random
import uuid

import cv2
import pandas as pd
from tools import get_file_paths, load_json


def clean_go_res(go_res_dir):
    max_seq_count = None
    max_seq_file_name = None
    seq_sum = 0
    file_count = 0

    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))

        # drop OCR boxes whose text is empty
        remove_key_set = set()
        go_res = load_json(go_res_json_path)
        for key, (_, text) in go_res.items():
            if text.strip() == '':
                remove_key_set.add(key)
                print(text)

        if len(remove_key_set) > 0:
            for del_key in remove_key_set:
                del go_res[del_key]

        # sort the boxes top-to-bottom, then left-to-right
        go_res_list = sorted(list(go_res.values()), key=lambda x: (x[0][1], x[0][0]), reverse=False)

        with open(go_res_json_path, 'w') as fp:
            json.dump(go_res_list, fp)
        print('Rewrite {0}'.format(go_res_json_path))

        seq_sum += len(go_res_list)
        file_count += 1
        if max_seq_count is None or len(go_res_list) > max_seq_count:
            max_seq_count = len(go_res_list)
            max_seq_file_name = go_res_json_path

    seq_lens_mean = seq_sum // file_count
    return max_seq_count, seq_lens_mean, max_seq_file_name

def text_statistics(go_res_dir):
    """
    Args:
        go_res_dir: str  directory of general-OCR JSON results
    Returns: list  the most frequent texts and their counts
    """
    json_count = 0
    text_dict = {}
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
        json_count += 1
        go_res = load_json(go_res_json_path)
        for _, text in go_res.values():
            if text in text_dict:
                text_dict[text] += 1
            else:
                text_dict[text] = 1
    top_text_list = []
    # sort by count, descending
    for text, count in sorted(text_dict.items(), key=lambda x: x[1], reverse=True):
        if text == '':
            continue
        # discard texts that appear in no more than a third of the files
        if count <= json_count // 3:
            break
        top_text_list.append((text, count))
    return top_text_list

def build_anno_file(dataset_dir, anno_file_path):
    img_list = os.listdir(dataset_dir)
    random.shuffle(img_list)
    df = pd.DataFrame(columns=['name'])
    df['name'] = img_list
    df.to_csv(anno_file_path)

def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
    """
    Args:
        img_dir: str  image directory
        go_res_dir: str  directory of general-OCR JSON results
        label_dir: str  directory of annotation JSON files
        top_text_list: list  the most frequent texts and their counts
        skip_list: list  image names to skip
        save_dir: str  output directory for the dataset
    """
    if os.path.exists(save_dir):
        return
    else:
        os.makedirs(save_dir, exist_ok=True)

    # invoice date, invoice code, machine-printed number, vehicle type, phone,
    # engine number, frame number (VIN), account number, bank of deposit, amount in figures
    group_cn_list = ['开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
    test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]

    for img_name in sorted(os.listdir(img_dir)):
        if img_name in skip_list:
            print('Info: skip {0}'.format(img_name))
            continue

        print('Info: start {0}'.format(img_name))
        image_path = os.path.join(img_dir, img_name)
        img = cv2.imread(image_path)
        h, w, _ = img.shape
        base_image_name, _ = os.path.splitext(img_name)
        go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
        go_res_list = load_json(go_res_json_path)

        valid_lens = len(go_res_list)

        # indices of boxes whose text is one of the frequent "key" texts
        top_text_idx_set = set()
        for top_text, _ in top_text_list:
            for go_idx, (_, text) in enumerate(go_res_list):
                if text == top_text:
                    top_text_idx_set.add(go_idx)
                    break

        label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
        label_res = load_json(label_json_path)

        # center point of each annotated field, or None if the field is missing
        group_list = []
        for group_id in test_group_id:
            for item in label_res.get("shapes", []):
                if item.get("group_id") == group_id:
                    x_list = []
                    y_list = []
                    for point in item['points']:
                        x_list.append(point[0])
                        y_list.append(point[1])
                    group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2])
                    break
            else:
                group_list.append(None)

        # center point of each OCR box
        go_center_list = []
        for (x0, y0, x1, y1, x2, y2, x3, y3), _ in go_res_list:
            xmin = min(x0, x1, x2, x3)
            ymin = min(y0, y1, y2, y3)
            xmax = max(x0, x1, x2, x3)
            ymax = max(y0, y1, y2, y3)
            xcenter = xmin + (xmax - xmin)/2
            ycenter = ymin + (ymax - ymin)/2
            go_center_list.append((xcenter, ycenter))

        # greedily match each annotated field to the nearest unmatched OCR box (L1 distance)
        label_idx_dict = dict()
        for label_idx, label_center_list in enumerate(group_list):
            if isinstance(label_center_list, list):
                min_go_key = None
                min_length = None
                for go_idx, (go_x_center, go_y_center) in enumerate(go_center_list):
                    if go_idx in top_text_idx_set or go_idx in label_idx_dict:
                        continue
                    length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1])
                    if min_go_key is None or length < min_length:
                        min_go_key = go_idx
                        min_length = length
                if min_go_key is not None:
                    label_idx_dict[min_go_key] = label_idx

        # pad/truncate every document to a fixed length of 200 boxes
        # (the longest cleaned sequence is 152; see the stats in __main__)
        X = list()
        y_true = list()
        for i in range(200):
            if i >= valid_lens:
                # padding row: zero features, no label
                X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.])
                y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
            elif i in top_text_idx_set:
                # "key" text: flag feature 1, corners normalized by image size
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
                X.append([1., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
            elif i in label_idx_dict:
                # "value" text: one-hot label at the matched field index
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
                X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                base_label_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
                base_label_list[label_idx_dict[i]] = 1
                y_true.append(base_label_list)
            else:
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
                X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

        all_data = [X, y_true, valid_lens]

        with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name))), 'w') as fp:
            json.dump(all_data, fp)

        # print('top text find:')
        # for i in top_text_idx_set:
        #     _, text = go_res_list[i]
        #     print(text)

        # print('-------------')
        # print('label value find:')
        # for k, v in label_idx_dict.items():
        #     _, text = go_res_list[k]
        #     print('{0}: {1}'.format(group_cn_list[v], text))

        # break


if __name__ == '__main__':
    base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
    go_dir = os.path.join(base_dir, 'go_res')
    dataset_save_dir = os.path.join(base_dir, 'dataset2')
    label_dir = os.path.join(base_dir, 'labeled')

    train_go_path = os.path.join(go_dir, 'train')
    train_image_path = os.path.join(label_dir, 'train', 'image')
    train_label_path = os.path.join(label_dir, 'train', 'label')
    train_dataset_dir = os.path.join(dataset_save_dir, 'train')
    train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')

    valid_go_path = os.path.join(go_dir, 'valid')
    valid_image_path = os.path.join(label_dir, 'valid', 'image')
    valid_label_path = os.path.join(label_dir, 'valid', 'label')
    valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
    valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')

    # max_seq_lens, seq_lens_mean, max_seq_file_name = clean_go_res(go_dir)
    # print(max_seq_lens)       # 152
    # print(max_seq_file_name)  # CH-B101805176_page_2_img_0.json
    # print(seq_lens_mean)      # 92

    # top_text_list = text_statistics(go_dir)
    # for t in top_text_list:
    #     print(t)

    filter_from_top_text_list = [
        ('机器编号', 496),
        ('购买方名称', 496),
        ('合格证号', 495),
        ('进口证明书号', 495),
        ('机打代码', 494),
        ('车辆类型', 492),
        ('完税凭证号码', 492),
        ('机打号码', 491),
        ('发动机号码', 491),
        ('主管税务', 491),
        ('价税合计', 489),
        ('机关及代码', 489),
        ('销货单位名称', 486),
        ('厂牌型号', 485),
        ('产地', 485),
        ('商检单号', 483),
        ('电话', 476),
        ('开户银行', 472),
        ('车辆识别代号/车架号码', 463),
        ('身份证号码', 454),
        ('吨位', 452),
        ('备注:一车一票', 439),
        ('地', 432),
        ('账号', 431),
        ('统一社会信用代码/', 424),
        ('限乘人数', 404),
        ('税额', 465),
        ('址', 392)
    ]

    skip_list_train = [
        'CH-B101910792-page-12.jpg',
        'CH-B101655312-page-13.jpg',
        'CH-B102278656.jpg',
        'CH-B101846620_page_1_img_0.jpg',
        'CH-B103062528-0.jpg',
        'CH-B102613120-3.jpg',
        'CH-B102997980-3.jpg',
        'CH-B102680060-3.jpg',
        # 'CH-B102995500-2.jpg',  # no labeled value
    ]

    skip_list_valid = [
        'CH-B102897920-2.jpg',
        'CH-B102551284-0.jpg',
        'CH-B102879376-2.jpg',
        'CH-B101509488-page-16.jpg',
        'CH-B102708352-2.jpg',
    ]

    build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
    build_anno_file(train_dataset_dir, train_anno_file_path)

    build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
    build_anno_file(valid_dataset_dir, valid_anno_file_path)
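A sanity check one might run on the generated samples; the path is the train split from the __main__ block above, and the assertions mirror the fixed layout built in build_dataset:

import json
import os
import random

train_dataset_dir = '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train'  # from __main__ above
name = random.choice(os.listdir(train_dataset_dir))
with open(os.path.join(train_dataset_dir, name)) as fp:
    X, y_true, valid_lens = json.load(fp)

assert len(X) == 200 and len(X[0]) == 9              # [key flag, 8 normalized corner coords]
assert len(y_true) == 200 and len(y_true[0]) == 10   # one-hot over the 10 fields
assert all(row == [0] * 10 for row in y_true[valid_lens:])  # padding rows carry no label
print(name, valid_lens)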
@@ -2,6 +2,7 @@ import copy
 import torch
 import inspect
 from utils.registery import LOSS_REGISTRY
+from utils import sequence_mask
 from torchvision.ops import sigmoid_focal_loss

 class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
@@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)

+class MaskedSigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
+
+    def __init__(self,
+                 weight=None,
+                 size_average=None,
+                 reduce=None,
+                 reduction: str = 'mean',
+                 alpha: float = 0.25,
+                 gamma: float = 2):
+        super().__init__(weight, size_average, reduce, reduction)
+        self.alpha = alpha
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, valid_lens) -> torch.Tensor:
+        weights = torch.ones_like(targets)
+        weights = sequence_mask(weights, valid_lens)
+        unweighted_loss = sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, reduction='none')
+        weighted_loss = (unweighted_loss * weights).mean(dim=-1)
+        return weighted_loss
+

 def register_sigmoid_focal_loss():
     LOSS_REGISTRY.register()(SigmoidFocalLoss)
+    LOSS_REGISTRY.register()(MaskedSigmoidFocalLoss)


 def register_torch_loss():
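The new loss zeroes the per-position focal loss beyond each sequence's valid length (via sequence_mask, added to utils/__init__.py below), averages over the class dimension, and leaves batch/sequence reduction to the caller (the solver calls loss.sum()). A small sanity check of the masking on random tensors:

import torch
from torchvision.ops import sigmoid_focal_loss
from utils import sequence_mask  # defined in utils/__init__.py below

inputs = torch.randn(2, 5, 10)              # raw logits
targets = torch.zeros(2, 5, 10)
targets[0, 0, 3] = 1.                       # one labeled position
valid_lens = torch.tensor([2, 4])           # the rest of each sequence is padding

weights = sequence_mask(torch.ones_like(targets), valid_lens)
loss = sigmoid_focal_loss(inputs, targets, reduction='none')
masked = (loss * weights).mean(dim=-1)      # [2, 5]

print(masked[0, 2:])  # all zeros: padded positions contribute nothing
print(masked[1, :4])  # nonzero: valid positions keep their loss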
@@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY

 from .mlp import MLPModel
 from .vit import VisionTransformer
+from .seq_labeling import SLTransformer


 def build_model(cfg):
model/seq_labeling.py
0 → 100644
import math
from functools import partial

import torch
import torch.nn as nn
from utils.registery import MODEL_REGISTRY
from utils import sequence_mask


def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis.
    Defined in :numref:`sec_attention-scoring-functions`"""
    # `X`: 4D attention-score tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        # [batch_size, num_heads, seq_len, seq_len]
        shape = X.shape
        if valid_lens.dim() == 1:
            # one length per sample: repeat it for every head and every query row
            # (the d2l original handles a 3D input; shape[1] * shape[2] extends
            # it to the 4D multi-head case)
            valid_lens = torch.repeat_interleave(valid_lens, shape[1] * shape[2])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)


class PositionalEncoding(nn.Module):
    """Positional encoding.
    Defined in :numref:`sec_self-attention-and-positional-encoding`"""
    def __init__(self, embed_dim, drop_ratio, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(drop_ratio)
        # Create a long enough `P`
        self.P = torch.zeros((1, max_len, embed_dim))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, embed_dim, 2, dtype=torch.float32) / embed_dim)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Attention(nn.Module):
    def __init__(self,
                 dim,  # input token dimension
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x, valid_lens):
        # [batch_size, seq_len, total_embed_dim]
        B, N, C = x.shape

        # qkv(): -> [batch_size, seq_len, 3 * total_embed_dim]
        # reshape: -> [batch_size, seq_len, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, seq_len, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, seq_len, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, seq_len]
        # @: multiply -> [batch_size, num_heads, seq_len, seq_len]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # attn = attn.softmax(dim=-1)
        attn = masked_softmax(attn, valid_lens)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, seq_len, embed_dim_per_head]
        # transpose: -> [batch_size, seq_len, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, seq_len, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x, valid_lens):
        # [batch_size, seq_len, total_embed_dim]
        x = x + self.drop_path(self.attn(self.norm1(x), valid_lens))
        # [batch_size, seq_len, total_embed_dim]
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


@MODEL_REGISTRY.register()
class SLTransformer(nn.Module):
    def __init__(self,
                 seq_lens=200,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 norm_layer=None,
                 act_layer=None,
                 ):
        """
        Args:
            seq_lens (int): padded sequence length
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            norm_layer (nn.Module): normalization layer
        """
        super(SLTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        # self.pos_embed = PositionalEncoding(self.embed_dim, drop_ratio, max_len=seq_lens)
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_lens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        # ModuleList rather than Sequential: each block also takes `valid_lens`
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(_init_vit_weights)

    def forward(self, x, valid_lens):
        # x: [B, seq_len, embed_dim]
        # valid_lens: [B, ]

        # TODO: sinusoidal positional encoding? Its values lie in [-1, 1], so the
        # embeddings would be scaled by sqrt(embed_dim) before adding the encoding:
        # x = self.pos_embed(x * math.sqrt(self.embed_dim))

        # learned positional embedding
        x = self.pos_drop(x + self.pos_embed)

        # [batch_size, seq_len, total_embed_dim]
        for block in self.blocks:
            x = block(x, valid_lens)
        x = self.norm(x)

        # [batch_size, seq_len, num_classes]
        x = self.head(x)
        return x
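A quick shape check with the hyper-parameters from config/sl.yaml; random tensors stand in for real OCR features:

import torch

model = SLTransformer(seq_lens=200, num_classes=10, embed_dim=9,
                      depth=6, num_heads=1, mlp_ratio=4.0, qkv_bias=True)

X = torch.rand(8, 200, 9)                  # [batch, seq_len, feature_dim]
valid_lens = torch.randint(50, 153, (8,))  # real box count per document

out = model(X, valid_lens)
print(out.shape)  # torch.Size([8, 200, 10]): one logit vector per OCR box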
@@ -3,6 +3,7 @@ import copy
 from utils.registery import SOLVER_REGISTRY
 from .mlp_solver import MLPSolver
 from .vit_solver import VITSolver
+from .sl_solver import SLSolver


 def build_solver(cfg):
solver/sl_solver.py
0 → 100644
import copy
import os

import torch
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
from utils import sequence_mask


@SOLVER_REGISTRY.register()
class SLSolver(object):

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        self.base_on = self.hyper_params['base_on']
        self.model_path = self.hyper_params['model_path']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError('solver.args must contain epoch')

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5):
        # rebuild a single class id per position: 1..num_classes for a field,
        # 0 for "no field" (all sigmoid scores below the threshold / all-zero label row)
        # [batch_size, seq_len, num_classes]
        y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
        # [batch_size, seq_len]
        y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
        # [batch_size, seq_len]
        y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=-1) + 1
        y_true_is_other = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        # padding positions become -1 and can never match a prediction
        masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)

        return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()

    def train_loop(self):
        self.model.train()

        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            train_loss += loss.sum()

            if batch % 100 == 0:
                loss_value, current = loss.sum().item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.sum().backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # accuracy is measured per valid position, not per document
        correct /= seq_lens_sum
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        seq_lens_sum = torch.zeros(1).to(self.device)
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y, valid_lens in self.val_loader:
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            val_loss += loss.sum()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        correct /= seq_lens_sum
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            self.model.load_state_dict(torch.load(self.base_on))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            lr_scheduler.step()

        self.logger.info('==> End Training')

    # def run(self):
    #     from torch.nn import functional

    #     y = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)
    #     valid_lens = torch.randint(50, 100, (8, ))
    #     print(valid_lens)

    #     pred = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)

    #     print(self.accuracy(pred, y, valid_lens))

    def evaluate(self):
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        label_true_list = []
        label_pred_list = []
        for X, y, valid_lens in self.val_loader:
            X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            pred = self.model(X, valid_lens)

            y_pred = torch.nn.Sigmoid()(pred)

            # rebuild class ids over the class axis, as in accuracy()
            y_pred_idx = torch.argmax(y_pred, dim=-1) + 1
            y_pred_is_other = (torch.amax(y_pred, dim=-1) > 0.5).int()
            y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

            y_true_idx = torch.argmax(y_true, dim=-1) + 1
            y_true_is_other = torch.sum(y_true, dim=-1).int()
            y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

            # keep only the valid (non-padding) positions
            valid_mask = torch.arange(y_true_rebuild.size(1), device=self.device)[None, :] < valid_lens[:, None]
            label_true_list.extend(y_true_rebuild[valid_mask].cpu().numpy().tolist())
            label_pred_list.extend(y_pred_rebuild[valid_mask].cpu().numpy().tolist())

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
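accuracy() collapses the one-hot layout to a single class id per box (1..10, with 0 for "no field") and only scores valid positions. A toy check in the spirit of the commented-out debugging run() above, calling the method unbound so no config is needed:

import torch
from torch.nn import functional

y_true = functional.one_hot(torch.randint(0, 10, (8, 200)), 10).float()
y_pred = y_true * 10.0 - 5.0               # logits: sigmoid ~0.99 for the true class, ~0.01 otherwise
valid_lens = torch.randint(50, 153, (8,))

# self is unused inside accuracy(), so the unbound method can be called directly
correct = SLSolver.accuracy(None, y_pred, y_true, valid_lens)
print(correct / valid_lens.sum().item())   # 1.0: every valid position matches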
@@ -1,8 +1,17 @@
+import torch
 from .registery import *
 from .logger import get_logger_and_log_dir

 __all__ = [
     'Registry',
     'get_logger_and_log_dir',
+    'sequence_mask',
 ]

+
+def sequence_mask(X, valid_len, value=0):
+    """Mask irrelevant entries in sequences.
+    Defined in :numref:`sec_seq2seq_decoder`"""
+    maxlen = X.size(1)
+    mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
+    X[~mask] = value
+    return X
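Behavior of the helper on a small example; note that it writes `value` into X in place (at every position at or beyond the sample's valid length) as well as returning it:

import torch

X = torch.ones(2, 4)
valid_len = torch.tensor([1, 3])
print(sequence_mask(X, valid_len))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 1., 0.]])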