add Seq Labeling solver

Showing 12 changed files with 324 additions and 1 deletion
config/sl.yaml
0 → 100644
seed: 3407

dataset:
  name: 'SLData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv'

dataloader:
  batch_size: 8
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'SLTransformer'
  args:
    seq_lens: 200
    num_classes: 10
    embed_dim: 9
    depth: 6
    num_heads: 1
    mlp_ratio: 4.0
    qkv_bias: true
    qk_scale: null
    drop_ratio: 0.
    attn_drop_ratio: 0.
    drop_path_ratio: 0.
    norm_layer: null
    act_layer: null

solver:
  name: 'SLSolver'
  args:
    epoch: 100
    base_on: null
    model_path: null

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-3
      # weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'CosineLR'
    args:
      epochs: 100
      lrf: 0.1

  loss:
    name: 'MaskedSigmoidFocalLoss'
    # name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"
      alpha: 0.95

  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'sl-6-1'
\ No newline at end of file
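One detail worth noting in this config: PyYAML's default resolver treats `1e-3` (no decimal point) as a string, so the explicit `!!float` tag is what keeps the learning rate numeric. A quick check:

import yaml

print(yaml.safe_load('lr: 1e-3'))          # {'lr': '1e-3'}  -- parsed as a string!
print(yaml.safe_load('lr: !!float 1e-3'))  # {'lr': 0.001}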
@@ -60,6 +60,7 @@ solver:
     # name: 'CrossEntropyLoss'
     args:
       reduction: "mean"
+      alpha: 0.95
 
   logger:
     log_root: '/Users/zhouweiqi/Downloads/test/logs'
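The `alpha: 0.95` added here is the positive-class weight in torchvision's `sigmoid_focal_loss`: each element's loss is scaled by `alpha * targets + (1 - alpha) * (1 - targets)`, so 0.95 up-weights the sparse positive entries of the multi-hot label rows 19:1 over negatives. A toy check (values invented):

import torch
from torchvision.ops import sigmoid_focal_loss

# at logit 0 both classes sit at p = 0.5, so only the alpha weighting differs
pos = sigmoid_focal_loss(torch.tensor([0.0]), torch.tensor([1.0]), alpha=0.95, gamma=2.0)
neg = sigmoid_focal_loss(torch.tensor([0.0]), torch.tensor([0.0]), alpha=0.95, gamma=2.0)
print(pos / neg)  # ~ tensor([19.])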
data/SLData.py
0 → 100644
import os
import json
import torch
from torch.utils.data import Dataset
import pandas as pd
from utils.registery import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SLData(Dataset):

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        series = self.df.iloc[idx]
        name = series['name']

        # each sample file stores an (input_list, label_list, valid_lens) triple
        with open(os.path.join(self.data_root, self.phase, name), 'r') as fp:
            input_list, label_list, valid_lens = json.load(fp)

        input_tensor = torch.tensor(input_list)
        label_tensor = torch.tensor(label_list).float()

        return input_tensor, label_tensor, valid_lens
\ No newline at end of file
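For reference, a minimal sketch of the per-sample JSON layout `__getitem__` expects, with dimensions taken from config/sl.yaml (the values themselves are invented for illustration):

import json

seq_lens, embed_dim, num_classes = 200, 9, 10  # from config/sl.yaml
sample = [
    [[0.0] * embed_dim for _ in range(seq_lens)],  # input_list: one feature row per position
    [[0] * num_classes for _ in range(seq_lens)],  # label_list: one multi-hot row per position
    137,                                           # valid_lens: count of non-padded positions (invented)
]
with open('example.json', 'w') as fp:
    json.dump(sample, fp)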
@@ -3,6 +3,7 @@ from torch.utils.data import DataLoader
 from utils.registery import DATASET_REGISTRY
 
 from .CoordinatesData import CoordinatesData
+from .SLData import SLData
 
 
 def build_dataset(cfg):
@@ -93,7 +93,8 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
     label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
     label_res = load_json(label_json_path)
 
     # invoice date, invoice code, machine-printed number, vehicle type, phone number
+    # engine number, frame number (VIN), account number, bank of deposit, amount in figures
     test_group_id = [1, 2, 5, 9, 20]
     group_list = []
     for group_id in test_group_id:
data/create_dataset2.py
0 → 100644
(new file; diff collapsed in this view)
@@ -2,6 +2,7 @@ import copy
 import torch
 import inspect
 from utils.registery import LOSS_REGISTRY
+from utils import sequence_mask
 from torchvision.ops import sigmoid_focal_loss
 
 class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
@@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)
 
+class MaskedSigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
+
+    def __init__(self,
+                 weight=None,
+                 size_average=None,
+                 reduce=None,
+                 reduction: str = 'mean',
+                 alpha: float = 0.25,
+                 gamma: float = 2):
+        super().__init__(weight, size_average, reduce, reduction)
+        self.alpha = alpha
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, valid_lens) -> torch.Tensor:
+        # zero the loss at padded positions beyond each sequence's valid length
+        weights = torch.ones_like(targets)
+        weights = sequence_mask(weights, valid_lens)
+        unweighted_loss = sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, reduction='none')
+        # mean over classes -> per-position loss of shape [batch_size, seq_len]
+        weighted_loss = (unweighted_loss * weights).mean(dim=-1)
+        return weighted_loss
+
 
 def register_sigmoid_focal_loss():
     LOSS_REGISTRY.register()(SigmoidFocalLoss)
+    LOSS_REGISTRY.register()(MaskedSigmoidFocalLoss)
 
 
 def register_torch_loss():
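A small sanity check of the new class, assuming `MaskedSigmoidFocalLoss` is importable from the `loss` package (shapes invented, matching the solver's usage):

import torch
from loss import MaskedSigmoidFocalLoss

loss_fn = MaskedSigmoidFocalLoss(alpha=0.95)
logits = torch.randn(2, 5, 10, requires_grad=True)   # [batch_size, seq_len, num_classes]
targets = torch.randint(0, 2, (2, 5, 10)).float()
valid_lens = torch.tensor([5, 3])                    # second sequence: 2 padded positions

per_position = loss_fn(logits, targets, valid_lens)  # [batch_size, seq_len]
assert per_position.shape == (2, 5)
assert torch.all(per_position[1, 3:] == 0)           # padding contributes no loss
per_position.sum().backward()                        # the solver backpropagates loss.sum()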
@@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY
 
 from .mlp import MLPModel
 from .vit import VisionTransformer
+from .seq_labeling import SLTransformer
 
 
 def build_model(cfg):
model/seq_labeling.py
0 → 100644
(new file; diff collapsed in this view)
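Since the model file is collapsed here, its interface can only be inferred from the args in config/sl.yaml and the solver's `self.model(X, valid_lens)` call sites. Purely as orientation, an invented stand-in (not the diff's actual SLTransformer):

import torch

class SLTransformerStub(torch.nn.Module):
    """Invented stand-in mirroring the inferred interface; not the real model."""

    def __init__(self, seq_lens=200, num_classes=10, embed_dim=9, **kwargs):
        super().__init__()
        # the real model presumably stacks `depth` attention blocks; a linear head suffices here
        self.head = torch.nn.Linear(embed_dim, num_classes)

    def forward(self, x, valid_lens):
        # x: [batch_size, seq_lens, embed_dim] -> logits [batch_size, seq_lens, num_classes]
        return self.head(x)

stub = SLTransformerStub()
logits = stub(torch.randn(8, 200, 9), torch.full((8,), 200))
assert logits.shape == (8, 200, 10)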
@@ -3,6 +3,7 @@ import copy
 from utils.registery import SOLVER_REGISTRY
 from .mlp_solver import MLPSolver
 from .vit_solver import VITSolver
+from .sl_solver import SLSolver
 
 
 def build_solver(cfg):
solver/sl_solver.py
0 → 100644
import copy
import os

import torch

from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
from utils import sequence_mask
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report


@SOLVER_REGISTRY.register()
class SLSolver(object):

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        self.base_on = self.hyper_params['base_on']
        self.model_path = self.hyper_params['model_path']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError("cfg['solver']['args'] must contain 'epoch'")

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
    def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5):
        # logits -> probabilities, [batch_size, seq_len, num_classes]
        y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)
        # best class per position, shifted to 1..num_classes, [batch_size, seq_len]
        y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
        # positions whose top probability clears the threshold; the rest fall
        # back to class 0 ("other"), [batch_size, seq_len]
        y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        # same rebuild for the targets: all-zero label rows become class 0 ("other")
        y_true_idx = torch.argmax(y_true, dim=-1) + 1
        y_true_is_other = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        # set padded positions to -1 so they can never count as correct
        masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)

        return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()
| 60 | |||
    def train_loop(self):
        self.model.train()

        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X, valid_lens)
            # [batch_size, seq_len, num_classes]

            # per-position loss, [batch_size, seq_len]
            loss = self.loss_fn(pred, y, valid_lens)
            train_loss += loss.sum()

            if batch % 100 == 0:
                loss_value, current = loss.sum().item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.sum().backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # accuracy is per valid token, not per sample
        # correct /= self.train_dataset_size
        correct /= seq_lens_sum
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        seq_lens_sum = torch.zeros(1).to(self.device)
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y, valid_lens in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            # pred = torch.nn.Sigmoid()(self.model(X))
            pred = self.model(X, valid_lens)
            # [batch_size, seq_len, num_classes]

            loss = self.loss_fn(pred, y, valid_lens)
            val_loss += loss.sum()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # correct /= self.val_dataset_size
        correct /= seq_lens_sum
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            self.model.load_state_dict(torch.load(self.base_on))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            lr_scheduler.step()

        self.logger.info('==> End Training')
| 142 | |||
    # quick self-test of accuracy() on random one-hot tensors:
    # def run(self):
    #     from torch.nn import functional

    #     y = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)
    #     valid_lens = torch.randint(50, 100, (8, ))
    #     print(valid_lens)

    #     pred = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)

    #     print(self.accuracy(pred, y, valid_lens))
| 153 | |||
    @torch.no_grad()
    def evaluate(self):
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        label_true_list = []
        label_pred_list = []
        for X, y, valid_lens in self.val_loader:
            X, y_true = X.to(self.device), y.to(self.device)

            pred = self.model(X, valid_lens)
            # [batch_size, seq_len, num_classes]
            y_pred = torch.nn.Sigmoid()(pred)

            # rebuild per-position labels the same way accuracy() does
            y_pred_idx = torch.argmax(y_pred, dim=-1) + 1
            y_pred_is_other = (torch.amax(y_pred, dim=-1) > 0.5).int()
            y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

            y_true_idx = torch.argmax(y_true, dim=-1) + 1
            y_true_is_other = torch.sum(y_true, dim=-1).int()
            y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

            # flatten, keeping only the valid (non-padded) positions
            for pred_row, true_row, valid_len in zip(y_pred_rebuild, y_true_rebuild, valid_lens):
                vl = int(valid_len)
                label_pred_list.extend(pred_row[:vl].cpu().numpy().tolist())
                label_true_list.extend(true_row[:vl].cpu().numpy().tolist())

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
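For orientation, a sketch of how this solver is presumably driven end to end. The runner script is not part of this diff, and the exact return convention of `build_solver` is an assumption here:

import yaml
from solver import build_solver

with open('config/sl.yaml') as fp:
    cfg = yaml.safe_load(fp)

solver = build_solver(cfg)  # assumed: resolves 'SLSolver' via SOLVER_REGISTRY and constructs it
solver.run()                # train, validate, and checkpoint every epoch
# solver.evaluate()         # or: reload the checkpoint at cfg['solver']['args']['model_path']
#                           # and print the sklearn confusion matrix / report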
+import torch
 from .registery import *
 from .logger import get_logger_and_log_dir
 
 __all__ = [
     'Registry',
     'get_logger_and_log_dir',
+    'sequence_mask',
 ]
 
+
+def sequence_mask(X, valid_len, value=0):
+    """Mask irrelevant entries in sequences (from d2l, :numref:`sec_seq2seq_decoder`).
+
+    Sets X[i, t, ...] = value for every position t >= valid_len[i];
+    note that X is modified in place.
+    """
+    maxlen = X.size(1)
+    mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
+    X[~mask] = value
+    return X
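A quick illustration of `sequence_mask` on toy values, including the `value=-1` variant the solver's `accuracy()` uses to keep padded positions from ever matching:

import torch
from utils import sequence_mask

valid_len = torch.tensor([3, 1])
print(sequence_mask(torch.ones(2, 4), valid_len))
# tensor([[1., 1., 1., 0.],
#         [1., 0., 0., 0.]])

print(sequence_mask(torch.ones(2, 4), valid_len, value=-1))
# tensor([[ 1.,  1.,  1., -1.],
#         [ 1., -1., -1., -1.]])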