add Seq Labeling solver
Showing
12 changed files
with
323 additions
and
0 deletions
config/sl.yaml
0 → 100644
1 | seed: 3407 | ||
2 | |||
3 | dataset: | ||
4 | name: 'SLData' | ||
5 | args: | ||
6 | data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2' | ||
7 | train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv' | ||
8 | val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv' | ||
9 | |||
10 | dataloader: | ||
11 | batch_size: 8 | ||
12 | num_workers: 4 | ||
13 | pin_memory: true | ||
14 | shuffle: true | ||
15 | |||
16 | model: | ||
17 | name: 'SLTransformer' | ||
18 | args: | ||
19 | seq_lens: 200 | ||
20 | num_classes: 10 | ||
21 | embed_dim: 9 | ||
22 | depth: 6 | ||
23 | num_heads: 1 | ||
24 | mlp_ratio: 4.0 | ||
25 | qkv_bias: true | ||
26 | qk_scale: null | ||
27 | drop_ratio: 0. | ||
28 | attn_drop_ratio: 0. | ||
29 | drop_path_ratio: 0. | ||
30 | norm_layer: null | ||
31 | act_layer: null | ||
32 | |||
33 | solver: | ||
34 | name: 'SLSolver' | ||
35 | args: | ||
36 | epoch: 100 | ||
37 | base_on: null | ||
38 | model_path: null | ||
39 | |||
40 | optimizer: | ||
41 | name: 'Adam' | ||
42 | args: | ||
43 | lr: !!float 1e-3 | ||
44 | # weight_decay: !!float 5e-5 | ||
45 | |||
46 | lr_scheduler: | ||
47 | name: 'CosineLR' | ||
48 | args: | ||
49 | epochs: 100 | ||
50 | lrf: 0.1 | ||
51 | |||
52 | loss: | ||
53 | name: 'MaskedSigmoidFocalLoss' | ||
54 | # name: 'SigmoidFocalLoss' | ||
55 | # name: 'CrossEntropyLoss' | ||
56 | args: | ||
57 | reduction: "mean" | ||
58 | alpha: 0.95 | ||
59 | |||
60 | logger: | ||
61 | log_root: '/Users/zhouweiqi/Downloads/test/logs' | ||
62 | suffix: 'sl-6-1' | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -60,6 +60,7 @@ solver: | ... | @@ -60,6 +60,7 @@ solver: |
60 | # name: 'CrossEntropyLoss' | 60 | # name: 'CrossEntropyLoss' |
61 | args: | 61 | args: |
62 | reduction: "mean" | 62 | reduction: "mean" |
63 | alpha: 0.95 | ||
63 | 64 | ||
64 | logger: | 65 | logger: |
65 | log_root: '/Users/zhouweiqi/Downloads/test/logs' | 66 | log_root: '/Users/zhouweiqi/Downloads/test/logs' | ... | ... |
data/SLData.py
0 → 100644
1 | import os | ||
2 | import json | ||
3 | import torch | ||
4 | from torch.utils.data import DataLoader, Dataset | ||
5 | import pandas as pd | ||
6 | from utils.registery import DATASET_REGISTRY | ||
7 | |||
8 | |||
@DATASET_REGISTRY.register()
class SLData(Dataset):
    """Sequence-labeling dataset.

    Each sample lives in its own JSON file under ``<data_root>/<phase>/`` and
    holds ``(input features, one-hot labels, valid length)``.  A CSV manifest
    (one file name per row, column ``name``) indexes the samples.
    """

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        self.data_root = data_root
        self.phase = phase
        # The manifest CSV drives both __len__ and sample lookup.
        self.df = pd.read_csv(anno_file)

    def __len__(self):
        # One manifest row per sample.
        return self.df.shape[0]

    def __getitem__(self, idx):
        sample_name = self.df.iloc[idx]['name']
        sample_path = os.path.join(self.data_root, self.phase, sample_name)

        with open(sample_path, 'r') as fp:
            feature_list, label_list, valid_lens = json.load(fp)

        features = torch.tensor(feature_list)
        labels = torch.tensor(label_list).float()
        return features, labels, valid_lens
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -3,6 +3,7 @@ from torch.utils.data import DataLoader | ... | @@ -3,6 +3,7 @@ from torch.utils.data import DataLoader |
3 | from utils.registery import DATASET_REGISTRY | 3 | from utils.registery import DATASET_REGISTRY |
4 | 4 | ||
5 | from .CoordinatesData import CoordinatesData | 5 | from .CoordinatesData import CoordinatesData |
6 | from .SLData import SLData | ||
6 | 7 | ||
7 | 8 | ||
8 | def build_dataset(cfg): | 9 | def build_dataset(cfg): | ... | ... |
... | @@ -94,6 +94,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -94,6 +94,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
94 | label_res = load_json(label_json_path) | 94 | label_res = load_json(label_json_path) |
95 | 95 | ||
96 | # 开票日期 发票代码 机打号码 车辆类型 电话 | 96 | # 开票日期 发票代码 机打号码 车辆类型 电话 |
97 | # 发动机号码 车架号 帐号 开户银行 小写 | ||
97 | test_group_id = [1, 2, 5, 9, 20] | 98 | test_group_id = [1, 2, 5, 9, 20] |
98 | group_list = [] | 99 | group_list = [] |
99 | for group_id in test_group_id: | 100 | for group_id in test_group_id: | ... | ... |
data/create_dataset2.py
0 → 100644
This diff is collapsed.
Click to expand it.
... | @@ -2,6 +2,7 @@ import copy | ... | @@ -2,6 +2,7 @@ import copy |
2 | import torch | 2 | import torch |
3 | import inspect | 3 | import inspect |
4 | from utils.registery import LOSS_REGISTRY | 4 | from utils.registery import LOSS_REGISTRY |
5 | from utils import sequence_mask | ||
5 | from torchvision.ops import sigmoid_focal_loss | 6 | from torchvision.ops import sigmoid_focal_loss |
6 | 7 | ||
7 | class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): | 8 | class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): |
... | @@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): | ... | @@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): |
21 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: | 22 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: |
22 | return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction) | 23 | return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction) |
23 | 24 | ||
class MaskedSigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
    """Sigmoid focal loss that ignores padded sequence positions.

    ``forward`` returns an unreduced ``(batch, seq_len)`` tensor: the focal
    loss averaged over the class dimension, with every position past a
    sample's valid length contributing exactly zero.

    NOTE(review): ``reduction`` is stored but never applied in ``forward`` —
    the output is always per-position; callers reduce it themselves.
    """

    def __init__(self,
                 weight=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight, size_average, reduce, reduction)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, valid_lens) -> torch.Tensor:
        # (batch, seq_len, num_classes) 0/1 mask: ones inside each valid length.
        mask = sequence_mask(torch.ones_like(targets), valid_lens)
        per_element = sigmoid_focal_loss(
            inputs, targets, self.alpha, self.gamma, reduction='none')
        # Average across the class axis; masked positions are zeroed out.
        return (per_element * mask).mean(dim=-1)
45 | |||
24 | 46 | ||
def register_sigmoid_focal_loss():
    """Register the focal-loss variants defined in this module."""
    for loss_cls in (SigmoidFocalLoss, MaskedSigmoidFocalLoss):
        LOSS_REGISTRY.register()(loss_cls)
27 | 50 | ||
28 | 51 | ||
29 | def register_torch_loss(): | 52 | def register_torch_loss(): | ... | ... |
... | @@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY | ... | @@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY |
3 | 3 | ||
4 | from .mlp import MLPModel | 4 | from .mlp import MLPModel |
5 | from .vit import VisionTransformer | 5 | from .vit import VisionTransformer |
6 | from .seq_labeling import SLTransformer | ||
6 | 7 | ||
7 | 8 | ||
8 | def build_model(cfg): | 9 | def build_model(cfg): | ... | ... |
model/seq_labeling.py
0 → 100644
This diff is collapsed.
Click to expand it.
... | @@ -3,6 +3,7 @@ import copy | ... | @@ -3,6 +3,7 @@ import copy |
3 | from utils.registery import SOLVER_REGISTRY | 3 | from utils.registery import SOLVER_REGISTRY |
4 | from .mlp_solver import MLPSolver | 4 | from .mlp_solver import MLPSolver |
5 | from .vit_solver import VITSolver | 5 | from .vit_solver import VITSolver |
6 | from .sl_solver import SLSolver | ||
6 | 7 | ||
7 | 8 | ||
8 | def build_solver(cfg): | 9 | def build_solver(cfg): | ... | ... |
solver/sl_solver.py
0 → 100644
1 | import copy | ||
2 | import os | ||
3 | |||
4 | import torch | ||
5 | |||
6 | from data import build_dataloader | ||
7 | from loss import build_loss | ||
8 | from model import build_model | ||
9 | from optimizer import build_lr_scheduler, build_optimizer | ||
10 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir | ||
11 | from utils import sequence_mask | ||
12 | from sklearn.metrics import confusion_matrix, accuracy_score, classification_report | ||
13 | |||
14 | |||
@SOLVER_REGISTRY.register()
class SLSolver(object):
    """Training/evaluation driver for the sequence-labeling transformer.

    Wires dataloaders, model, loss, optimizer and LR scheduler together from
    a config dict and runs the epoch loop.  The loss function is expected to
    return an unreduced ``(batch, seq_len)`` tensor (see
    ``MaskedSigmoidFocalLoss``), which is summed here before backprop.
    """

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        self.base_on = self.hyper_params['base_on']        # optional checkpoint to resume training from
        self.model_path = self.hyper_params['model_path']  # checkpoint loaded by evaluate()
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            # The original code raised a plain string here, which is itself a
            # TypeError in Python 3; raise a real exception instead.
            raise KeyError("cfg['solver']['args'] must contain 'epoch'")

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5):
        """Count correctly labeled positions within each sample's valid length.

        y_pred: (batch, seq_len, num_classes) raw logits.
        y_true: (batch, seq_len, num_classes) one-hot labels; an all-zero row
            marks the background/"other" class.
        valid_lens: (batch,) number of real (non-padding) positions per sample.
        thresholds: minimum sigmoid score for a position to count as a class.
        Returns the number of matching positions as a Python int.
        """
        y_pred_sigmoid = torch.sigmoid(y_pred)
        # Predicted class id, shifted by 1 so that 0 can mean "other".
        # [batch_size, seq_len]
        y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
        # A position is a real class only if its best score clears the threshold.
        # [batch_size, seq_len]
        y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=-1) + 1
        y_true_is_other = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        # Set padded target positions to -1 so they can never match a prediction.
        masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)

        return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()

    def train_loop(self):
        """One full pass over the training set; logs running loss and accuracy."""
        self.model.train()

        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            # pred: [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            batch_loss = loss.sum()
            # Detach before accumulating so the autograd graph of each
            # iteration is not kept alive for the whole epoch.
            train_loss += batch_loss.detach()

            if batch % 100 == 0:
                loss_value, current = batch_loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # Accuracy is per valid position, not per sample.
        correct /= seq_lens_sum
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        """One pass over the validation set at epoch *t*; logs mean loss and accuracy."""
        self.model.eval()

        seq_lens_sum = torch.zeros(1).to(self.device)
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y, valid_lens in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            # pred: [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            val_loss += loss.sum()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        correct /= seq_lens_sum
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        """Write the model weights for *epoch_id* into the run's log directory."""
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        """Full training run: optional warm start, then epoch loop with LR scheduling."""
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            self.model.load_state_dict(torch.load(self.base_on))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            lr_scheduler.step()

        self.logger.info('==> End Training')

    @torch.no_grad()
    def evaluate(self):
        """Score ``model_path`` on the validation set with sklearn metrics.

        Fixes over the original version: the val loader yields 3-tuples
        ``(X, y, valid_lens)`` (it previously unpacked only two values and
        crashed), the model is called with ``valid_lens`` as in val_loop,
        class reductions use the last axis of the (batch, seq_len, classes)
        tensors, and padded positions are excluded from the metrics.
        """
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        label_true_list = []
        label_pred_list = []
        for X, y, valid_lens in self.val_loader:
            X, y_true = X.to(self.device), y.to(self.device)

            pred = self.model(X, valid_lens)
            y_pred = torch.sigmoid(pred)

            # Same rebuild scheme as accuracy(): class id + 1, 0 for "other".
            y_pred_idx = torch.argmax(y_pred, dim=-1) + 1
            y_pred_is_other = (torch.amax(y_pred, dim=-1) > 0.5).int()
            y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

            y_true_idx = torch.argmax(y_true, dim=-1) + 1
            y_true_is_other = torch.sum(y_true, dim=-1).int()
            y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

            # Only score positions inside each sample's valid length.
            seq_len = y_true_rebuild.size(1)
            valid = (torch.arange(seq_len, device=self.device)[None, :]
                     < valid_lens.to(self.device)[:, None])
            label_true_list.extend(y_true_rebuild[valid].cpu().numpy().tolist())
            label_pred_list.extend(y_pred_rebuild[valid].cpu().numpy().tolist())

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
1 | import torch | ||
1 | from .registery import * | 2 | from .registery import * |
2 | from .logger import get_logger_and_log_dir | 3 | from .logger import get_logger_and_log_dir |
3 | 4 | ||
4 | __all__ = [ | 5 | __all__ = [ |
5 | 'Registry', | 6 | 'Registry', |
6 | 'get_logger_and_log_dir', | 7 | 'get_logger_and_log_dir', |
8 | 'sequence_mask', | ||
7 | ] | 9 | ] |
8 | 10 | ||
def sequence_mask(X, valid_len, value=0):
    """Overwrite entries of ``X`` beyond each row's valid length with ``value``.

    X: tensor whose second dimension is the sequence axis, e.g.
        (batch, seq_len) or (batch, seq_len, ...).
    valid_len: (batch,) tensor of per-row valid lengths.
    Mutates ``X`` in place and also returns it.
    (Adapted from d2l's :numref:`sec_seq2seq_decoder`.)
    """
    seq_len = X.size(1)
    positions = torch.arange(seq_len, dtype=torch.float32, device=X.device)
    # True at every position at or past the row's valid length.
    invalid = positions[None, :] >= valid_len[:, None]
    X[invalid] = value
    return X
-
Please register or sign in to post a comment