40ca6fe1 by 周伟奇

add Seq Labeling solver

1 parent b3694ec8
seed: 3407
dataset:
name: 'SLData'
args:
data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2'
train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/train.csv'
val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset2/valid.csv'
dataloader:
batch_size: 8
num_workers: 4
pin_memory: true
shuffle: true
model:
name: 'SLTransformer'
args:
seq_lens: 200
num_classes: 10
embed_dim: 9
depth: 6
num_heads: 1
mlp_ratio: 4.0
qkv_bias: true
qk_scale: null
drop_ratio: 0.
attn_drop_ratio: 0.
drop_path_ratio: 0.
norm_layer: null
act_layer: null
solver:
name: 'SLSolver'
args:
epoch: 100
base_on: null
model_path: null
optimizer:
name: 'Adam'
args:
lr: !!float 1e-3
# weight_decay: !!float 5e-5
lr_scheduler:
name: 'CosineLR'
args:
epochs: 100
lrf: 0.1
loss:
name: 'MaskedSigmoidFocalLoss'
# name: 'SigmoidFocalLoss'
# name: 'CrossEntropyLoss'
args:
reduction: "mean"
alpha: 0.95
logger:
log_root: '/Users/zhouweiqi/Downloads/test/logs'
suffix: 'sl-6-1'
\ No newline at end of file
......@@ -60,6 +60,7 @@ solver:
# name: 'CrossEntropyLoss'
args:
reduction: "mean"
alpha: 0.95
logger:
log_root: '/Users/zhouweiqi/Downloads/test/logs'
......
import os
import json
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from utils.registery import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class SLData(Dataset):
    """Sequence-labeling dataset.

    Each sample is a JSON file (named in the annotation CSV) holding a
    triple: (input feature list, one-hot label list, valid sequence length).
    """

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        self.data_root = data_root
        # CSV with a 'name' column listing the per-sample JSON file names.
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sample_path = os.path.join(self.data_root, self.phase, row['name'])
        with open(sample_path, 'r') as fp:
            features, labels, valid_lens = json.load(fp)
        return torch.tensor(features), torch.tensor(labels).float(), valid_lens
\ No newline at end of file
......@@ -3,6 +3,7 @@ from torch.utils.data import DataLoader
from utils.registery import DATASET_REGISTRY
from .CoordinatesData import CoordinatesData
from .SLData import SLData
def build_dataset(cfg):
......
......@@ -94,6 +94,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
label_res = load_json(label_json_path)
# 开票日期 发票代码 机打号码 车辆类型 电话
# 发动机号码 车架号 帐号 开户银行 小写
test_group_id = [1, 2, 5, 9, 20]
group_list = []
for group_id in test_group_id:
......
import copy
import json
import os
import random
import uuid
import cv2
import pandas as pd
from tools import get_file_paths, load_json
def clean_go_res(go_res_dir):
    """Clean general-OCR JSON results in place and collect sequence stats.

    For every ``.json`` file under *go_res_dir*:
      * drop entries whose recognized text is empty or whitespace,
      * sort the remaining entries by the first point's (y, x) — i.e.
        top-to-bottom, then left-to-right,
      * rewrite the file as a JSON list of the sorted entries.

    Args:
        go_res_dir: directory containing general-OCR JSON result files.

    Returns:
        (max_seq_count, seq_lens_mean, max_seq_file_name): the longest
        sequence length, the mean length (integer division), and the path
        of the longest file.  (None, 0, None) when the directory is empty.
    """
    max_seq_count = None
    # Initialize so an empty directory cannot leave the name unbound.
    max_seq_file_name = None
    seq_sum = 0
    file_count = 0
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
        remove_key_set = set()
        go_res = load_json(go_res_json_path)
        for key, (_, text) in go_res.items():
            if text.strip() == '':
                remove_key_set.add(key)
                print(text)
        for del_key in remove_key_set:
            del go_res[del_key]
        go_res_list = sorted(go_res.values(), key=lambda x: (x[0][1], x[0][0]))
        with open(go_res_json_path, 'w') as fp:
            json.dump(go_res_list, fp)
        print('Rewrite {0}'.format(go_res_json_path))
        seq_sum += len(go_res_list)
        file_count += 1
        if max_seq_count is None or len(go_res_list) > max_seq_count:
            max_seq_count = len(go_res_list)
            max_seq_file_name = go_res_json_path
    # Guard against an empty directory (the original divided by zero here).
    seq_lens_mean = seq_sum // file_count if file_count else 0
    return max_seq_count, seq_lens_mean, max_seq_file_name
def text_statistics(go_res_dir):
    """Count how often each recognized text appears across OCR JSON files.

    Args:
        go_res_dir: directory of general-OCR JSON result files.

    Returns:
        list of (text, count) pairs, most frequent first; empty texts are
        skipped, and counting stops once a count drops to one third of the
        file total or below.
    """
    json_count = 0
    text_dict = {}
    for json_path in get_file_paths(go_res_dir, ['.json', ]):
        print('Info: start {0}'.format(json_path))
        json_count += 1
        for _, text in load_json(json_path).values():
            text_dict[text] = text_dict.get(text, 0) + 1

    top_text_list = []
    # Rank by occurrence count, descending.
    for text, count in sorted(text_dict.items(), key=lambda kv: kv[1], reverse=True):
        if text == '':
            continue
        # Stop at the first text appearing in <= 1/3 of the files.
        if count <= json_count // 3:
            break
        top_text_list.append((text, count))
    return top_text_list
def build_anno_file(dataset_dir, anno_file_path):
    """Write a shuffled CSV listing of the files in *dataset_dir*.

    The CSV has a single 'name' column, one row per file, in random order.
    """
    names = os.listdir(dataset_dir)
    random.shuffle(names)
    frame = pd.DataFrame(columns=['name'])
    frame['name'] = names
    frame.to_csv(anno_file_path)
def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
    """
    Build per-image training samples for the sequence-labeling model.

    For every image: load its OCR boxes, flag boxes whose text is a known
    frequent "key" text, match each annotated field (by group id) to the
    nearest unclaimed OCR box, and dump a (features, one-hot labels,
    valid_length) triple as JSON into *save_dir*.  Does nothing if
    *save_dir* already exists.

    Args:
        img_dir: str          image directory
        go_res_dir: str       directory of general-OCR JSON results
        label_dir: str        directory of annotation JSON files
        top_text_list: list   most frequent texts and their counts
        skip_list: list       image file names to skip
        save_dir: str         dataset output directory
    """
    if os.path.exists(save_dir):
        return
    else:
        os.makedirs(save_dir, exist_ok=True)

    # Target fields (Chinese): invoice date, invoice code, printed number,
    # vehicle type, phone, engine number, VIN, account number, bank, amount.
    group_cn_list = ['开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
    # group_id of each field above inside the annotation JSON.
    test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]
    for img_name in sorted(os.listdir(img_dir)):
        if img_name in skip_list:
            print('Info: skip {0}'.format(img_name))
            continue
        print('Info: start {0}'.format(img_name))
        image_path = os.path.join(img_dir, img_name)
        img = cv2.imread(image_path)
        # Image size is used to normalize box coordinates to [0, 1].
        h, w, _ = img.shape
        base_image_name, _ = os.path.splitext(img_name)

        go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
        go_res_list = load_json(go_res_json_path)
        # Number of real (non-padding) OCR boxes for this image.
        valid_lens = len(go_res_list)

        # Indices of OCR boxes whose text exactly matches a frequent "key"
        # text (these are field names, never field values).
        top_text_idx_set = set()
        for top_text, _ in top_text_list:
            for go_idx, (_, text) in enumerate(go_res_list):
                if text == top_text:
                    top_text_idx_set.add(go_idx)
                    break

        label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
        label_res = load_json(label_json_path)
        # Center point of each annotated field polygon, aligned with
        # test_group_id; None when the group id is absent.
        group_list = []
        for group_id in test_group_id:
            for item in label_res.get("shapes", []):
                if item.get("group_id") == group_id:
                    x_list = []
                    y_list = []
                    for point in item['points']:
                        x_list.append(point[0])
                        y_list.append(point[1])
                    group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2])
                    break
            else:
                # for/else: no shape with this group_id was found.
                group_list.append(None)

        # Center point of each OCR quadrilateral.
        go_center_list = []
        for (x0, y0, x1, y1, x2, y2, x3, y3), _ in go_res_list:
            xmin = min(x0, x1, x2, x3)
            ymin = min(y0, y1, y2, y3)
            xmax = max(x0, x1, x2, x3)
            ymax = max(y0, y1, y2, y3)
            xcenter = xmin + (xmax - xmin)/2
            ycenter = ymin + (ymax - ymin)/2
            go_center_list.append((xcenter, ycenter))

        # Greedy matching: each annotated field claims the nearest
        # (L1 distance between centers) OCR box that is neither a key text
        # nor already claimed.  Maps OCR index -> field index.
        label_idx_dict = dict()
        for label_idx, label_center_list in enumerate(group_list):
            if isinstance(label_center_list, list):
                min_go_key = None
                min_length = None
                for go_idx, (go_x_center, go_y_center) in enumerate(go_center_list):
                    if go_idx in top_text_idx_set or go_idx in label_idx_dict:
                        continue
                    length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1])
                    if min_go_key is None or length < min_length:
                        min_go_key = go_idx
                        min_length = length
                if min_go_key is not None:
                    label_idx_dict[min_go_key] = label_idx

        # Fixed-length (200) feature/label sequences; positions past
        # valid_lens are zero padding.  Feature = [is_key_text flag,
        # 8 normalized corner coordinates]; label = one-hot over 10 fields
        # (all zeros means "no field").
        X = list()
        y_true = list()
        for i in range(200):
            if i >= valid_lens:
                X.append([0., 0., 0., 0., 0., 0., 0., 0., 0.])
                y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
            elif i in top_text_idx_set:
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
                X.append([1., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
            elif i in label_idx_dict:
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
                X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                base_label_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
                base_label_list[label_idx_dict[i]] = 1
                y_true.append(base_label_list)
            else:
                (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[i]
                X.append([0., x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                y_true.append([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

        all_data = [X, y_true, valid_lens]
        # Deterministic output file name derived from the image name.
        with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, img_name))), 'w') as fp:
            json.dump(all_data, fp)

        # Debug helpers (left commented out intentionally):
        # print('top text find:')
        # for i in top_text_idx_set:
        #     _, text = go_res_list[i]
        #     print(text)
        # print('-------------')
        # print('label value find:')
        # for k, v in label_idx_dict.items():
        #     _, text = go_res_list[k]
        #     print('{0}: {1}'.format(group_cn_list[v], text))
        # break
if __name__ == '__main__':
    # Directory layout of the GCFP invoice data.
    base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
    go_dir = os.path.join(base_dir, 'go_res')
    dataset_save_dir = os.path.join(base_dir, 'dataset2')
    label_dir = os.path.join(base_dir, 'labeled')

    train_go_path = os.path.join(go_dir, 'train')
    train_image_path = os.path.join(label_dir, 'train', 'image')
    train_label_path = os.path.join(label_dir, 'train', 'label')
    train_dataset_dir = os.path.join(dataset_save_dir, 'train')
    train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')

    valid_go_path = os.path.join(go_dir, 'valid')
    valid_image_path = os.path.join(label_dir, 'valid', 'image')
    valid_label_path = os.path.join(label_dir, 'valid', 'label')
    valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
    valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')

    # One-off statistics that informed seq_lens=200 and the key-text list.
    # max_seq_lens, seq_lens_mean, max_seq_file_name = clean_go_res(go_dir)
    # print(max_seq_lens)       # 152
    # print(max_seq_file_name)  # CH-B101805176_page_2_img_0.json
    # print(seq_lens_mean)      # 92

    # top_text_list = text_statistics(go_dir)
    # for t in top_text_list:
    #     print(t)

    # Frequent "key" texts (text, occurrence count), hand-filtered from the
    # output of text_statistics().
    filter_from_top_text_list = [
        ('机器编号', 496),
        ('购买方名称', 496),
        ('合格证号', 495),
        ('进口证明书号', 495),
        ('机打代码', 494),
        ('车辆类型', 492),
        ('完税凭证号码', 492),
        ('机打号码', 491),
        ('发动机号码', 491),
        ('主管税务', 491),
        ('价税合计', 489),
        ('机关及代码', 489),
        ('销货单位名称', 486),
        ('厂牌型号', 485),
        ('产地', 485),
        ('商检单号', 483),
        ('电话', 476),
        ('开户银行', 472),
        ('车辆识别代号/车架号码', 463),
        ('身份证号码', 454),
        ('吨位', 452),
        ('备注:一车一票', 439),
        ('地', 432),
        ('账号', 431),
        ('统一社会信用代码/', 424),
        ('限乘人数', 404),
        ('税额', 465),
        ('址', 392)
    ]

    # Training images with broken or unusable annotations.
    skip_list_train = [
        'CH-B101910792-page-12.jpg',
        'CH-B101655312-page-13.jpg',
        'CH-B102278656.jpg',
        'CH-B101846620_page_1_img_0.jpg',
        'CH-B103062528-0.jpg',
        'CH-B102613120-3.jpg',
        'CH-B102997980-3.jpg',
        'CH-B102680060-3.jpg',
        # 'CH-B102995500-2.jpg',  # no field values
    ]

    # Validation images with broken or unusable annotations.
    skip_list_valid = [
        'CH-B102897920-2.jpg',
        'CH-B102551284-0.jpg',
        'CH-B102879376-2.jpg',
        'CH-B101509488-page-16.jpg',
        'CH-B102708352-2.jpg',
    ]

    build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
    build_anno_file(train_dataset_dir, train_anno_file_path)

    build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
    build_anno_file(valid_dataset_dir, valid_anno_file_path)
......@@ -2,6 +2,7 @@ import copy
import torch
import inspect
from utils.registery import LOSS_REGISTRY
from utils import sequence_mask
from torchvision.ops import sigmoid_focal_loss
class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
......@@ -21,9 +22,31 @@ class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)
class MaskedSigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
    """Sigmoid focal loss that ignores positions past each sequence's
    valid length.

    NOTE: `reduction` is stored for interface compatibility, but forward()
    always averages over the class dimension and returns a per-position
    loss tensor of shape [batch, seq_len].
    """

    def __init__(self,
                 weight=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight, size_average, reduce, reduction)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, valid_lens) -> torch.Tensor:
        # Build a 0/1 mask that zeroes positions beyond each valid length.
        mask = sequence_mask(torch.ones_like(targets), valid_lens)
        # Per-element focal loss, masked, then averaged over classes.
        raw = sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, reduction='none')
        return (raw * mask).mean(dim=-1)
def register_sigmoid_focal_loss():
    """Register both focal-loss variants with the global loss registry."""
    for loss_cls in (SigmoidFocalLoss, MaskedSigmoidFocalLoss):
        LOSS_REGISTRY.register()(loss_cls)
def register_torch_loss():
......
......@@ -3,6 +3,7 @@ from utils import MODEL_REGISTRY
from .mlp import MLPModel
from .vit import VisionTransformer
from .seq_labeling import SLTransformer
def build_model(cfg):
......
import math
from functools import partial
from collections import OrderedDict
import torch
import torch.nn as nn
from utils.registery import MODEL_REGISTRY
from utils import sequence_mask
def masked_softmax(X, valid_lens):
    """Softmax over the last axis, ignoring positions past `valid_lens`.

    X: attention-score tensor, e.g. [batch, heads, seq_len, seq_len].
    valid_lens: per-sample lengths, 1D [batch] or 2D [batch, seq_len];
        None disables masking entirely.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    # Expand lengths so there is exactly one length per softmax row.
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, shape[2])
    else:
        valid_lens = valid_lens.reshape(-1)
    # Masked positions get a very large negative score, which the
    # exponentiation inside softmax turns into ~0.
    flat = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(flat.reshape(shape), dim=-1)
class PositionalEncoding(nn.Module):
    """Fixed sine/cosine positional encoding followed by dropout.

    Precomputes encodings for all positions up to `max_len`; forward()
    adds the first seq_len rows to the input.
    """

    def __init__(self, embed_dim, drop_ratio, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(drop_ratio)
        # angles[p, i] = p / 10000^(2i / embed_dim)
        position = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        div = torch.pow(10000, torch.arange(0, embed_dim, 2, dtype=torch.float32) / embed_dim)
        angles = position / div
        # Even channels take sin, odd channels take cos.
        self.P = torch.zeros((1, max_len, embed_dim))
        self.P[:, :, 0::2] = torch.sin(angles)
        self.P[:, :, 1::2] = torch.cos(angles)

    def forward(self, X):
        # Add the encoding for the first seq_len positions, then dropout.
        return self.dropout(X + self.P[:, :X.shape[1], :].to(X.device))
def _init_vit_weights(m):
"""
ViT weight initialization
:param m: module
"""
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: randomly zero whole samples, rescale survivors.

    Per-sample Bernoulli keep mask (probability 1 - drop_prob), broadcast
    over all non-batch dimensions.  Identity when drop_prob is 0 or when
    not training.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One random draw per sample; shape broadcasts over remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    # Divide by keep_prob so the expected activation is unchanged.
    return x.div(keep_prob) * mask
class DropPath(nn.Module):
    """nn.Module wrapper around `drop_path` (per-sample stochastic depth)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        # Probability of dropping an entire sample's residual branch.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Active only in training mode; identity in eval mode.
        return drop_path(x, self.drop_prob, self.training)
class Attention(nn.Module):
    """Multi-head self-attention with a length-masked softmax."""

    def __init__(self,
                 dim,  # input token dimension
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Default scale is 1/sqrt(head_dim) unless explicitly overridden.
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x, valid_lens):
        # x: [batch, seq_len, dim]
        batch, seq_len, dim = x.shape
        head_dim = dim // self.num_heads
        # Joint q/k/v projection, then split heads:
        # -> [3, batch, heads, seq_len, head_dim]
        qkv = (self.qkv(x)
               .reshape(batch, seq_len, 3, self.num_heads, head_dim)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv.unbind(0)

        # Scaled dot-product scores: [batch, heads, seq_len, seq_len]
        scores = (q @ k.transpose(-2, -1)) * self.scale
        # Softmax that ignores key positions past each sample's valid length.
        weights = self.attn_drop(masked_softmax(scores, valid_lens))

        # Weighted sum of values, heads merged back: [batch, seq_len, dim]
        out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, dim)
        return self.proj_drop(self.proj(out))
class Mlp(nn.Module):
    """Two-layer feed-forward block (Linear -> activation -> Linear) with
    dropout after each layer, as used in ViT / MLP-Mixer."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Both sizes default to the input size when not given.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Block(nn.Module):
    """Transformer encoder block: pre-norm masked attention plus pre-norm
    MLP, each wrapped in a residual connection with optional stochastic
    depth on the residual branch."""

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # Identity when the ratio is zero, so train/eval behave the same.
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop_ratio)

    def forward(self, x, valid_lens):
        # Residual 1: masked self-attention over the normalized input.
        x = x + self.drop_path(self.attn(self.norm1(x), valid_lens))
        # Residual 2: position-wise MLP over the normalized result.
        return x + self.drop_path(self.mlp(self.norm2(x)))
@MODEL_REGISTRY.register()
class SLTransformer(nn.Module):
    """Transformer encoder for sequence labeling: maps a fixed-length
    sequence of per-box features to per-position class logits."""

    def __init__(self,
                 seq_lens=200,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 norm_layer=None,
                 act_layer=None,
                 ):
        """
        Args:
            seq_lens (int): fixed input sequence length (the learned
                positional embedding has exactly this many positions)
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): activation layer for the MLP blocks
        """
        super(SLTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        # Learned positional embedding; the fixed sin/cos variant is kept
        # commented out as an alternative.
        # self.pos_embed = PositionalEncoding(self.embed_dim, drop_ratio, max_len=seq_lens)
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_lens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(_init_vit_weights)

    def forward(self, x, valid_lens):
        # x: [B, seq_len, embed_dim]
        # valid_lens: [B, ]
        # TODO sin/cos positional encoding?
        # Positional-encoding values lie in [-1, 1], so the embeddings
        # would be scaled by sqrt(embed_dim) before adding the encoding:
        # x = self.pos_embed(x * math.sqrt(self.embed_dim))

        # Learned positional embedding, added element-wise.
        x = self.pos_drop(x + self.pos_embed)

        # [batch_size, seq_len, total_embed_dim]
        # Iterate the blocks manually (instead of calling the Sequential)
        # because every block also needs valid_lens.
        for block in self.blocks:
            x = block(x, valid_lens)
        # x = self.blocks(x, valid_lens)
        x = self.norm(x)

        # [batch_size, seq_len, num_classes]
        x = self.head(x)
        return x
\ No newline at end of file
......@@ -3,6 +3,7 @@ import copy
from utils.registery import SOLVER_REGISTRY
from .mlp_solver import MLPSolver
from .vit_solver import VITSolver
from .sl_solver import SLSolver
def build_solver(cfg):
......
import copy
import os
import torch
from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
from utils import sequence_mask
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
@SOLVER_REGISTRY.register()
class SLSolver(object):
    """Training/evaluation driver for the sequence-labeling transformer.

    Everything is built from the config dict: dataloaders, model, loss,
    optimizer, lr scheduler and logging destination.
    """

    def __init__(self, cfg):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = copy.deepcopy(cfg)
        self.train_loader, self.val_loader = build_dataloader(cfg)
        # Batch counts vs sample counts (both are used for averaging).
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)
        self.model = build_model(cfg).to(self.device)
        self.loss_fn = build_loss(cfg)
        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
        self.hyper_params = cfg['solver']['args']
        # base_on: checkpoint to resume training from; model_path: checkpoint
        # to evaluate.  Both may be null in the config.
        self.base_on = self.hyper_params['base_on']
        self.model_path = self.hyper_params['model_path']
        try:
            self.epoch = self.hyper_params['epoch']
        except KeyError as e:
            # The original raised a bare string, which is itself a TypeError
            # in Python 3; raise a proper exception instead.
            raise KeyError('solver.args must contain "epoch"') from e
        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    def accuracy(self, y_pred, y_true, valid_lens, thresholds=0.5):
        """Count correctly labeled positions within each valid length.

        Args:
            y_pred: [batch, seq_len, num_classes] raw logits.
            y_true: [batch, seq_len, num_classes] one-hot labels
                (all-zero row means "no field").
            valid_lens: [batch] valid sequence lengths.
            thresholds: sigmoid probability cutoff below which a position
                is predicted as "no field" (class id 0).

        Returns:
            int: number of positions where prediction matches the label.
        """
        probs = torch.sigmoid(y_pred)
        # Rebuild per-position class ids in 1..C; 0 means "no field".
        pred_idx = torch.argmax(probs, dim=-1) + 1
        pred_has_label = (torch.amax(probs, dim=-1) > thresholds).int()
        y_pred_rebuild = torch.multiply(pred_idx, pred_has_label)

        true_idx = torch.argmax(y_true, dim=-1) + 1
        true_has_label = torch.sum(y_true, dim=-1).int()
        y_true_rebuild = torch.multiply(true_idx, true_has_label)

        # Force padded positions to -1 so they can never count as correct.
        masked_true = sequence_mask(y_true_rebuild, valid_lens, value=-1)
        return torch.sum((y_pred_rebuild == masked_true).int()).item()

    def train_loop(self):
        """One epoch of training; logs running loss and position accuracy."""
        self.model.train()
        seq_lens_sum = torch.zeros(1).to(self.device)
        train_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for batch, (X, y, valid_lens) in enumerate(self.train_loader):
            # Move valid_lens too — the original left it on CPU, which
            # breaks masking when the model runs on CUDA.
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            pred = self.model(X, valid_lens)

            # Loss is per-position [batch, seq_len]; sum for backprop.
            loss = self.loss_fn(pred, y, valid_lens)
            train_loss += loss.sum()

            if batch % 100 == 0:
                loss_value, current = loss.sum().item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.sum().backward()
            self.optimizer.step()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        # Accuracy is normalized by the number of valid positions.
        correct /= seq_lens_sum
        train_loss /= self.train_loader_size
        self.logger.info(f'train accuracy: {correct.item() :.4f}, train mean loss: {train_loss.item() :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        """One pass over the validation set; logs accuracy and mean loss."""
        self.model.eval()
        seq_lens_sum = torch.zeros(1).to(self.device)
        val_loss = torch.zeros(1).to(self.device)
        correct = torch.zeros(1).to(self.device)
        for X, y, valid_lens in self.val_loader:
            X, y, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)

            pred = self.model(X, valid_lens)

            loss = self.loss_fn(pred, y, valid_lens)
            val_loss += loss.sum()

            seq_lens_sum += valid_lens.sum()
            correct += self.accuracy(pred, y, valid_lens)

        correct /= seq_lens_sum
        val_loss /= self.val_loader_size
        self.logger.info(f"val accuracy: {correct.item() :.4f}, val mean loss: {val_loss.item() :.4f}")

    def save_checkpoint(self, epoch_id):
        """Save the model's state dict under the log directory."""
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        """Full training loop: optional resume, train/val per epoch,
        checkpoint every epoch, step the lr scheduler."""
        if isinstance(self.base_on, str) and os.path.exists(self.base_on):
            self.model.load_state_dict(torch.load(self.base_on))
            self.logger.info(f'==> Load Model from {self.base_on}')

        self.logger.info('==> Start Training')
        print(self.model)

        lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')
            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)
            lr_scheduler.step()

        self.logger.info('==> End Training')

    @torch.no_grad()
    def evaluate(self):
        """Load the configured checkpoint and print sklearn metrics
        (accuracy, confusion matrix, classification report) computed over
        every valid position of the validation set."""
        if isinstance(self.model_path, str) and os.path.exists(self.model_path):
            self.model.load_state_dict(torch.load(self.model_path))
            self.logger.info(f'==> Load Model from {self.model_path}')
        else:
            return

        self.model.eval()
        label_true_list = []
        label_pred_list = []
        # The val loader yields (X, y, valid_lens) 3-tuples; the original
        # unpacked only two values and called the model without valid_lens,
        # which raised immediately.
        for X, y, valid_lens in self.val_loader:
            X, y_true, valid_lens = X.to(self.device), y.to(self.device), valid_lens.to(self.device)
            pred = self.model(X, valid_lens)
            probs = torch.sigmoid(pred)

            # Rebuild per-position class ids (0 = "no field"), mirroring
            # accuracy(); reductions are over the class axis (dim=-1), not
            # dim=1 as the original 2-D code assumed.
            y_pred_idx = torch.argmax(probs, dim=-1) + 1
            y_pred_is_other = (torch.amax(probs, dim=-1) > 0.5).int()
            y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

            y_true_idx = torch.argmax(y_true, dim=-1) + 1
            y_true_is_other = torch.sum(y_true, dim=-1).int()
            y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

            # Only score positions inside each sample's valid length.
            for row_pred, row_true, n in zip(y_pred_rebuild, y_true_rebuild, valid_lens):
                n = int(n)
                label_pred_list.extend(row_pred[:n].cpu().numpy().tolist())
                label_true_list.extend(row_true[:n].cpu().numpy().tolist())

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
import torch
from .registery import *
from .logger import get_logger_and_log_dir
__all__ = [
'Registry',
'get_logger_and_log_dir',
'sequence_mask',
]
def sequence_mask(X, valid_len, value=0):
    """Overwrite entries past each row's valid length with `value`.

    Args:
        X: tensor whose second dimension is the sequence axis; it is
            modified in place.
        valid_len: 1D tensor of per-row valid lengths.
        value: fill value for the masked positions.

    Returns:
        The (mutated) input tensor.
    """
    seq_len = X.size(1)
    positions = torch.arange(seq_len, dtype=torch.float32, device=X.device)
    keep = positions[None, :] < valid_len[:, None]
    X[~keep] = value
    return X
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!