40ca6fe1 by 周伟奇

add Seq Labeling solver

1 parent b3694ec8
1 seed: 3407
2
3 dataset:
4 name: 'SLData'
5 args:
6 data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2'
7 train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv'
8 val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv'
9
10 dataloader:
11 batch_size: 8
12 num_workers: 4
13 pin_memory: true
14 shuffle: true
15
16 model:
17 name: 'SLTransformer'
18 args:
19 seq_lens: 200
20 num_classes: 10
21 embed_dim: 9
22 depth: 6
23 num_heads: 1
24 mlp_ratio: 4.0
25 qkv_bias: true
26 qk_scale: null
27 drop_ratio: 0.
28 attn_drop_ratio: 0.
29 drop_path_ratio: 0.
30 norm_layer: null
31 act_layer: null
32
33 solver:
34 name: 'SLSolver'
35 args:
36 epoch: 100
37 base_on: null
38 model_path: null
39
40 optimizer:
41 name: 'Adam'
42 args:
43 lr: !!float 1e-3
44 # weight_decay: !!float 5e-5
45
46 lr_scheduler:
47 name: 'CosineLR'
48 args:
49 epochs: 100
50 lrf: 0.1
51
52 loss:
53 name: 'MaskedSigmoidFocalLoss'
54 # name: 'SigmoidFocalLoss'
55 # name: 'CrossEntropyLoss'
56 args:
57 reduction: "mean"
58 alpha: 0.95
59
60 logger:
61 log_root: '/Users/zhouweiqi/Downloads/test/logs'
62 suffix: 'sl-6-1'
...\ No newline at end of file ...\ No newline at end of file
...@@ -60,6 +60,7 @@ solver: ...@@ -60,6 +60,7 @@ solver:
60 # name: 'CrossEntropyLoss' 60 # name: 'CrossEntropyLoss'
61 args: 61 args:
62 reduction: "mean" 62 reduction: "mean"
63 alpha: 0.95
63 64
64 logger: 65 logger:
65 log_root: '/Users/zhouweiqi/Downloads/test/logs' 66 log_root: '/Users/zhouweiqi/Downloads/test/logs'
......
1 import os
2 import json
3 import torch
4 from torch.utils.data import DataLoader, Dataset
5 import pandas as pd
6 from utils.registery import DATASET_REGISTRY
7
8
@DATASET_REGISTRY.register()
class SLData(Dataset):
    """Sequence-labeling dataset.

    Each row of the annotation CSV names a JSON file under
    ``<data_root>/<phase>/``; that file holds a ``[inputs, labels,
    valid_lens]`` triple for one sample.
    """

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        self.df = pd.read_csv(anno_file)
        self.data_root = data_root
        self.phase = phase

    def __len__(self):
        """Number of annotated samples."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(input_tensor, label_tensor, valid_lens)`` for sample ``idx``."""
        sample_name = self.df.iloc[idx]['name']
        sample_path = os.path.join(self.data_root, self.phase, sample_name)

        with open(sample_path, 'r') as fp:
            inputs, labels, valid_lens = json.load(fp)

        # Labels go to float for the sigmoid-focal loss downstream.
        return torch.tensor(inputs), torch.tensor(labels).float(), valid_lens
...\ No newline at end of file ...\ No newline at end of file
...@@ -3,6 +3,7 @@ from torch.utils.data import DataLoader ...@@ -3,6 +3,7 @@ from torch.utils.data import DataLoader
3 from utils.registery import DATASET_REGISTRY 3 from utils.registery import DATASET_REGISTRY
4 4
5 from .CoordinatesData import CoordinatesData 5 from .CoordinatesData import CoordinatesData
6 from .SLData import SLData
6 7
7 8
8 def build_dataset(cfg): 9 def build_dataset(cfg):
......
...@@ -94,6 +94,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save ...@@ -94,6 +94,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
94 label_res = load_json(label_json_path) 94 label_res = load_json(label_json_path)
95 95
96 # 开票日期 发票代码 机打号码 车辆类型 电话 96 # 开票日期 发票代码 机打号码 车辆类型 电话
97 # 发动机号码 车架号 帐号 开户银行 小写
97 test_group_id = [1, 2, 5, 9, 20] 98 test_group_id = [1, 2, 5, 9, 20]
98 group_list = [] 99 group_list = []
99 for group_id in test_group_id: 100 for group_id in test_group_id:
......
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
2 import torch 2 import torch
3 import inspect 3 import inspect
4 from utils.registery import LOSS_REGISTRY 4 from utils.registery import LOSS_REGISTRY
5 from utils import sequence_mask
5 from torchvision.ops import sigmoid_focal_loss 6 from torchvision.ops import sigmoid_focal_loss
6 7
7 class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): 8 class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
...@@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss): ...@@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
21 def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 22 def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
22 return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction) 23 return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)
23 24
class MaskedSigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
    """Sigmoid focal loss that zeroes out padded sequence positions.

    ``valid_lens`` gives the number of real (non-padded) positions per
    sample; contributions past that length are masked to zero before the
    per-position mean over the class dimension is taken.
    """

    def __init__(self,
                 weight=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight, size_average, reduce, reduction)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, valid_lens) -> torch.Tensor:
        # 1.0 at valid positions, 0.0 at padding.
        mask = sequence_mask(torch.ones_like(targets), valid_lens)
        # Per-element focal loss; reduction happens here, not in torchvision.
        per_element = sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, reduction='none')
        # Mean over the class dimension -> one loss value per position.
        return (per_element * mask).mean(dim=-1)
45
24 46
25 def register_sigmoid_focal_loss(): 47 def register_sigmoid_focal_loss():
26 LOSS_REGISTRY.register()(SigmoidFocalLoss) 48 LOSS_REGISTRY.register()(SigmoidFocalLoss)
49 LOSS_REGISTRY.register()(MaskedSigmoidFocalLoss)
27 50
28 51
29 def register_torch_loss(): 52 def register_torch_loss():
......
...@@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY ...@@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY
3 3
4 from .mlp import MLPModel 4 from .mlp import MLPModel
5 from .vit import VisionTransformer 5 from .vit import VisionTransformer
6 from .seq_labeling import SLTransformer
6 7
7 8
8 def build_model(cfg): 9 def build_model(cfg):
......
...@@ -3,6 +3,7 @@ import copy ...@@ -3,6 +3,7 @@ import copy
3 from utils.registery import SOLVER_REGISTRY 3 from utils.registery import SOLVER_REGISTRY
4 from .mlp_solver import MLPSolver 4 from .mlp_solver import MLPSolver
5 from .vit_solver import VITSolver 5 from .vit_solver import VITSolver
6 from .sl_solver import SLSolver
6 7
7 8
8 def build_solver(cfg): 9 def build_solver(cfg):
......
1 import copy
2 import os
3
4 import torch
5
6 from data import build_dataloader
7 from loss import build_loss
8 from model import build_model
9 from optimizer import build_lr_scheduler, build_optimizer
10 from utils import SOLVER_REGISTRY, get_logger_and_log_dir
11 from utils import sequence_mask
12 from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
13
14
@SOLVER_REGISTRY.register()
class SLSolver(object):
    """Trainer/evaluator for the sequence-labeling transformer.

    Wires together dataloaders, model, loss, optimizer and lr scheduler
    from ``cfg`` and runs the train/val/checkpoint cycle.  Expects
    ``cfg['solver']['args']`` to provide ``epoch``, ``base_on`` and
    ``model_path``.
    """

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # NOTE(review): if the model contains BatchNorm, confirm train/eval
        # mode handling is sufficient.
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        self.base_on = self.hyper_params['base_on']
        self.model_path = self.hyper_params['model_path']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError as err:
            # The original raised a bare string, which is itself a TypeError
            # in Python 3; raise a real exception with the intended message.
            raise KeyError('should contain epoch in {solver.args}') from err

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5):
        """Count correct token-level predictions within each valid length.

        Args:
            y_pred: raw logits, [batch_size, seq_len, num_classes].
            y_true: one-hot targets (an all-zero row encodes the "other"
                class), [batch_size, seq_len, num_classes].
            valid_lens: per-sample count of real (non-padded) positions.
            thresholds: sigmoid confidence below which a position is
                predicted as "other" (label 0).

        Returns:
            int: number of positions whose rebuilt label matches.
        """
        # [batch_size, seq_len, num_classes]
        y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
        # Class index shifted by +1 so 0 is free to encode "other".
        y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
        # 1 where the model is confident in some class, else 0 -> "other".
        y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=-1) + 1
        y_true_is_other = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        # Padded positions become -1, which no prediction (always >= 0) can
        # equal, so they never count as correct.
        masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)

        return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()

    def train_loop(self):
        """One pass over the training set; logs loss every 100 iterations."""
        self.model.train()

        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            # Move valid_lens as well: sequence_mask builds its mask on the
            # data tensor's device, so a CPU valid_lens breaks CUDA runs.
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            train_loss += loss.sum()

            if batch % 100 == 0:
                loss_value, current = loss.sum().item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.sum().backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # Token-level accuracy: normalize by valid positions, not samples.
        correct /= seq_lens_sum
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        """One pass over the validation set; logs accuracy and mean loss.

        Args:
            t: 1-based epoch number (currently unused; kept for callers).
        """
        self.model.eval()

        seq_lens_sum = torch.zeros(1).to(self.device)
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y, valid_lens in self.val_loader:
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            val_loss += loss.sum()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        correct /= seq_lens_sum
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        """Serialize the model weights into the log directory."""
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        """Train for self.epoch epochs, optionally warm-starting from base_on."""
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            self.model.load_state_dict(torch.load(self.base_on))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            lr_scheduler.step()

        self.logger.info('==> End Training')

    @torch.no_grad()
    def evaluate(self):
        """Load ``self.model_path`` and print sklearn metrics on the val set.

        Only positions inside each sample's valid length are scored.
        Returns silently when ``model_path`` is unset or missing.
        """
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        label_true_list = []
        label_pred_list = []
        # Bug fixes vs the original: the loader yields 3-tuples (it unpacked
        # two), the model requires valid_lens, and argmax/amax must reduce
        # the class dimension (dim=-1), not the sequence dimension (dim=1).
        for X, y, valid_lens in self.val_loader:
            X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            pred = self.model(X, valid_lens)
            y_pred = torch.nn.Sigmoid()(pred)

            # Same label rebuilding as accuracy().
            y_pred_idx = torch.argmax(y_pred, dim=-1) + 1
            y_pred_is_other = (torch.amax(y_pred, dim=-1) > 0.5).int()
            y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

            y_true_idx = torch.argmax(y_true, dim=-1) + 1
            y_true_is_other = torch.sum(y_true, dim=-1).int()
            y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

            # Boolean mask selecting only the valid (non-padded) positions.
            keep = sequence_mask(torch.ones_like(y_true_rebuild), valid_lens).bool()

            label_true_list.extend(y_true_rebuild[keep].cpu().numpy().tolist())
            label_pred_list.extend(y_pred_rebuild[keep].cpu().numpy().tolist())

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
1 import torch
1 from .registery import * 2 from .registery import *
2 from .logger import get_logger_and_log_dir 3 from .logger import get_logger_and_log_dir
3 4
4 __all__ = [ 5 __all__ = [
5 'Registry', 6 'Registry',
6 'get_logger_and_log_dir', 7 'get_logger_and_log_dir',
8 'sequence_mask',
7 ] 9 ]
8 10
def sequence_mask(X, valid_len, value=0):
    """Mask out entries beyond each sequence's valid length.

    Adapted from d2l's ``sec_seq2seq_decoder``.  Positions ``j >=
    valid_len[i]`` along dim 1 of ``X`` are set to ``value``.  Works for
    2-D ``[batch, seq]`` and higher-rank ``[batch, seq, ...]`` tensors
    (the mask indexes the first two dims).

    Args:
        X: tensor whose dim 1 is the sequence dimension.
        valid_len: 1-D tensor of per-sample valid lengths; assumed to be
            on the same device as ``X``.
        value: fill value for masked positions.

    Returns:
        A new tensor; unlike the original implementation, the caller's
        ``X`` is no longer mutated in place.
    """
    maxlen = X.size(1)
    # [1, maxlen] < [batch, 1] -> [batch, maxlen] boolean validity mask.
    mask = (torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :]
            < valid_len[:, None])
    out = X.clone()  # avoid surprising in-place mutation of the argument
    out[~mask] = value
    return out
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!