40ca6fe1 by 周伟奇

add Seq Labeling solver

1 parent b3694ec8
seed: 3407

dataset:
  name: 'SLData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv'
  dataloader:
    batch_size: 8
    num_workers: 4
    pin_memory: true
    shuffle: true

model:
  name: 'SLTransformer'
  args:
    seq_lens: 200
    num_classes: 10
    embed_dim: 9
    depth: 6
    num_heads: 1
    mlp_ratio: 4.0
    qkv_bias: true
    qk_scale: null
    drop_ratio: 0.
    attn_drop_ratio: 0.
    drop_path_ratio: 0.
    norm_layer: null
    act_layer: null

solver:
  name: 'SLSolver'
  args:
    epoch: 100
    base_on: null
    model_path: null
  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-3
      # weight_decay: !!float 5e-5
  lr_scheduler:
    name: 'CosineLR'
    args:
      epochs: 100
      lrf: 0.1
  loss:
    name: 'MaskedSigmoidFocalLoss'
    # name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"
      alpha: 0.95
  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'sl-6-1'
\ No newline at end of file
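For reference, a minimal sketch of how a config like this can be consumed. The repo's actual entry point is not part of this commit, so the file path 'config/sl.yaml' and the plain yaml.safe_load call below are assumptions, not the project's confirmed loader.

# Assumed config loading; 'config/sl.yaml' is a placeholder path.
import yaml

with open('config/sl.yaml', 'r') as fp:
    cfg = yaml.safe_load(fp)

print(cfg['solver']['name'])                      # 'SLSolver'
print(cfg['solver']['optimizer']['args']['lr'])   # 0.001 (!!float 1e-3)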
......@@ -60,6 +60,7 @@ solver:
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"
      alpha: 0.95
  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
......
import os
import json

import torch
from torch.utils.data import Dataset
import pandas as pd

from utils.registery import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SLData(Dataset):

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        series = self.df.iloc[idx]
        name = series['name']
        # Each sample is a JSON file storing [input_list, label_list, valid_lens]
        with open(os.path.join(self.data_root, self.phase, name), 'r') as fp:
            input_list, label_list, valid_lens = json.load(fp)
        input_tensor = torch.tensor(input_list)
        label_tensor = torch.tensor(label_list).float()
        return input_tensor, label_tensor, valid_lens
\ No newline at end of file
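A quick smoke test for SLData, assuming the annotation CSV has a 'name' column and each referenced JSON file stores [input_list, label_list, valid_lens] as the loader above expects. The paths below are placeholders, and the printed shapes follow the config (seq_lens 200, embed_dim 9, num_classes 10).

# Hypothetical usage of SLData; paths are placeholders.
from torch.utils.data import DataLoader

dataset = SLData(data_root='/path/to/dataset2',
                 anno_file='/path/to/dataset2/train.csv',
                 phase='train')
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)

inputs, labels, valid_lens = next(iter(loader))
print(inputs.shape)   # e.g. torch.Size([8, 200, 9])
print(labels.shape)   # e.g. torch.Size([8, 200, 10])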
......@@ -3,6 +3,7 @@ from torch.utils.data import DataLoader
from utils.registery import DATASET_REGISTRY

from .CoordinatesData import CoordinatesData
from .SLData import SLData


def build_dataset(cfg):
......
......@@ -94,6 +94,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
    label_res = load_json(label_json_path)

    # invoice date, invoice code, machine-printed number, vehicle type, phone
    # engine number, frame number, account number, bank of deposit, amount in figures
    test_group_id = [1, 2, 5, 9, 20]
    group_list = []
    for group_id in test_group_id:
......
......@@ -2,6 +2,7 @@ import copy
import torch
import inspect

from utils.registery import LOSS_REGISTRY
from utils import sequence_mask
from torchvision.ops import sigmoid_focal_loss


class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
......@@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)


class MaskedSigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):

    def __init__(self,
                 weight=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight, size_average, reduce, reduction)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, valid_lens) -> torch.Tensor:
        # Build a [batch_size, seq_len, num_classes] weight tensor that zeroes
        # every position beyond each sequence's valid length.
        weights = torch.ones_like(targets)
        weights = sequence_mask(weights, valid_lens)
        unweighted_loss = sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, reduction='none')
        # Mean over the class dimension only; returns a [batch_size, seq_len] loss map.
        weighted_loss = (unweighted_loss * weights).mean(dim=-1)
        return weighted_loss


def register_sigmoid_focal_loss():
    LOSS_REGISTRY.register()(SigmoidFocalLoss)
    LOSS_REGISTRY.register()(MaskedSigmoidFocalLoss)


def register_torch_loss():
......
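To make the masking behavior concrete, here is a self-contained sketch that feeds random tensors through MaskedSigmoidFocalLoss. The shapes are shortened from the config's seq_lens of 200 to 6 for readability; everything else mirrors how SLSolver calls the loss.

# Standalone sanity check of MaskedSigmoidFocalLoss with dummy data.
import torch

batch_size, seq_len, num_classes = 2, 6, 10
logits = torch.randn(batch_size, seq_len, num_classes)
targets = torch.randint(0, 2, (batch_size, seq_len, num_classes)).float()
valid_lens = torch.tensor([4, 6])   # positions 4 and 5 of the first sample are padding

loss_fn = MaskedSigmoidFocalLoss(reduction='mean', alpha=0.95)
loss = loss_fn(logits, targets, valid_lens)   # per-position loss: [batch_size, seq_len]
assert loss.shape == (batch_size, seq_len)
assert torch.all(loss[0, 4:] == 0)   # padded positions contribute exactly zero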
......@@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY
from .mlp import MLPModel
from .vit import VisionTransformer
from .seq_labeling import SLTransformer


def build_model(cfg):
......
......@@ -3,6 +3,7 @@ import copy
from utils.registery import SOLVER_REGISTRY

from .mlp_solver import MLPSolver
from .vit_solver import VITSolver
from .sl_solver import SLSolver


def build_solver(cfg):
......
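The registry module itself (utils/registery.py) is untouched by this commit and not shown, so the following is only a guess at its shape, consistent with the @SOLVER_REGISTRY.register() decorator usage and the LOSS_REGISTRY.register()(SigmoidFocalLoss) calls seen in this diff.

# Assumed minimal Registry, inferred from usage; not the repo's actual code.
class Registry:
    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        # Returning a decorator lets both @REG.register() and REG.register()(cls) work.
        def deco(obj):
            self._obj_map[obj.__name__] = obj
            return obj
        return deco

    def get(self, name):
        return self._obj_map[name]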
import copy
import os

import torch
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
from utils import sequence_mask


@SOLVER_REGISTRY.register()
class SLSolver(object):

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg).to(self.device)

        self.loss_fn = build_loss(cfg)
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        self.base_on = self.hyper_params['base_on']
        self.model_path = self.hyper_params['model_path']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError:
            raise KeyError("cfg['solver']['args'] must contain 'epoch'")

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
    def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5):
        # y_pred: [batch_size, seq_len, num_classes] logits
        y_pred_sigmoid = torch.nn.Sigmoid()(y_pred)

        # [batch_size, seq_len]: best class index, shifted by 1 so 0 is reserved for "other"
        y_pred_idx = torch.argmax(y_pred_sigmoid, dim=-1) + 1
        # [batch_size, seq_len]: 1 if any class clears the threshold, else 0 ("other")
        y_pred_is_other = (torch.amax(y_pred_sigmoid, dim=-1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        # All-zero label rows encode "other", so their sum is 0
        y_true_idx = torch.argmax(y_true, dim=-1) + 1
        y_true_is_other = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        # Padded positions become -1, so they can never match a prediction
        masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)

        return torch.sum((y_pred_rebuild == masked_y_true_rebuild).int()).item()
    def train_loop(self):
        self.model.train()

        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            X, y = X.to(self.device), y.to(self.device)

            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            train_loss += loss.sum()

            if batch % 100 == 0:
                loss_value, current = loss.sum().item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.sum().backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # correct /= self.train_dataset_size
        correct /= seq_lens_sum
        train_loss /= self.train_loader_size

        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')
    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        seq_lens_sum = torch.zeros(1).to(self.device)
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y, valid_lens in self.val_loader:
            X, y = X.to(self.device), y.to(self.device)

            # pred = torch.nn.Sigmoid()(self.model(X))
            # [batch_size, seq_len, num_classes]
            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            val_loss += loss.sum()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # correct /= self.val_dataset_size
        correct /= seq_lens_sum
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")
    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            self.model.load_state_dict(torch.load(self.base_on))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            lr_scheduler.step()

        self.logger.info('==> End Training')

    # Sanity check for accuracy() on random one-hot data:
    # def run(self):
    #     from torch.nn import functional
    #     y = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)
    #     valid_lens = torch.randint(50, 100, (8, ))
    #     print(valid_lens)
    #     pred = functional.one_hot(torch.randint(0, 10, (8, 100)), 10)
    #     print(self.accuracy(pred, y, valid_lens))
    def evaluate(self):
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()

        label_true_list = []
        label_pred_list = []
        for X, y, valid_lens in self.val_loader:
            X, y_true = X.to(self.device), y.to(self.device)

            # pred = torch.nn.Sigmoid()(self.model(X))
            pred = self.model(X, valid_lens)

            y_pred = torch.nn.Sigmoid()(pred)
            y_pred_idx = torch.argmax(y_pred, dim=-1) + 1
            y_pred_is_other = (torch.amax(y_pred, dim=-1) > 0.5).int()
            y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

            y_true_idx = torch.argmax(y_true, dim=-1) + 1
            y_true_is_other = torch.sum(y_true, dim=-1).int()
            y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

            # Flatten [batch_size, seq_len] into per-token labels for sklearn
            label_true_list.extend(y_true_rebuild.cpu().numpy().flatten().tolist())
            label_pred_list.extend(y_pred_rebuild.cpu().numpy().flatten().tolist())

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)

        print(acc)
        print(cm)
        print(report)
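Putting the pieces together, a hedged sketch of a driver script: it builds the solver from the config via the registry and starts training. load_yaml_cfg and the config path are placeholders for whatever the repo's real entry point does.

# Hypothetical entry point; not part of this commit.
import yaml
from solver import build_solver

def load_yaml_cfg(path):
    with open(path, 'r') as fp:
        return yaml.safe_load(fp)

if __name__ == '__main__':
    cfg = load_yaml_cfg('config/sl.yaml')
    solver = build_solver(cfg)   # resolves 'SLSolver' through SOLVER_REGISTRY
    solver.run()                 # or solver.evaluate() once model_path is set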
import torch

from .registery import *
from .logger import get_logger_and_log_dir

__all__ = [
    'Registry',
    'get_logger_and_log_dir',
    'sequence_mask',
]


def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences.

    Defined in :numref:`sec_seq2seq_decoder`"""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X
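This is the d2l-style sequence_mask; note it mutates X in place as well as returning it, which is why both the loss and the accuracy metric above pass in freshly built tensors. A small demonstration:

# sequence_mask overwrites entries past each sequence's valid length.
import torch

X = torch.ones(2, 5)
valid_len = torch.tensor([2, 4])
print(sequence_mask(X, valid_len, value=-1))
# tensor([[ 1.,  1., -1., -1., -1.],
#         [ 1.,  1.,  1.,  1., -1.]])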
......