bcb17d0f by 周伟奇

first commit

0 parents
1 .DS_Store
2 logs/
1 ## Intro
2
3 An experimental information-structuring approach based on classifying OCR box coordinates.
4
5 ## Usage
6
7 ```
8 pip install -r requirements.txt
9
10 python3 main.py --config=path/to/config.yaml
11 ```
12
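The paths in `config/mlp.yaml` below point at a prepared dataset directory. A sketch of the layout the dataset and config code in this commit appears to expect (the root path is the default from the config; file names are illustrative):

```
<data_root>/
├── train.csv        # annotation file: a 'name' column listing the sample JSON files
├── valid.csv
├── train/
│   └── <uuid>.json  # one sample: [input_coordinates_list, label_list]
└── valid/
    └── <uuid>.json
```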
1 seed: 3407
2
3 dataset:
4   name: 'CoordinatesData'
5   args:
6     data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
7     train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
8     val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'
9
10 dataloader:
11   batch_size: 32
12   num_workers: 4
13   pin_memory: true
14   shuffle: true
15
16 model:
17   name: 'MLPModel'
18   args:
19     activation: 'relu'
20
21 solver:
22   name: 'MLPSolver'
23   args:
24     epoch: 100
25
26   optimizer:
27     name: 'Adam'
28     args:
29       lr: !!float 1e-4
30       weight_decay: !!float 5e-5
31
32   lr_scheduler:
33     name: 'StepLR'
34     args:
35       step_size: 15
36       gamma: 0.1
37
38   loss:
39     name: 'SigmoidFocalLoss'
40     # name: 'CrossEntropyLoss'
41     args:
42       reduction: "mean"
43
44   logger:
45     log_root: '/Users/zhouweiqi/Downloads/test/logs'
46     suffix: 'mlp'
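For orientation, a sketch of how the builders in this commit consume the config; the nesting of `optimizer`, `lr_scheduler`, `loss` and `logger` under `solver` mirrors the lookups in `solver/mlp_solver.py` and the builder modules:

```python
import yaml

# minimal sketch, assuming this file is saved as config/mlp.yaml (the default in main.py)
with open('config/mlp.yaml') as fp:
    cfg = yaml.load(fp, Loader=yaml.FullLoader)

cfg['dataset']['args']          # -> CoordinatesData(**args) for the train/valid splits
cfg['dataloader']               # -> DataLoader(dataset, **dataloader)
cfg['model']['args']            # -> MLPModel(**args)
cfg['solver']['args']           # -> epoch count used by MLPSolver
cfg['solver']['optimizer']      # -> torch.optim.Adam(model.parameters(), **args)
cfg['solver']['lr_scheduler']   # -> StepLR (registered, but commented out in MLPSolver.run)
cfg['solver']['loss']           # -> SigmoidFocalLoss(**args)
cfg['solver']['logger']         # -> get_logger_and_log_dir(**logger)
```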
1 import os
2 import json
3 import torch
4 from torch.utils.data import DataLoader, Dataset
5 import pandas as pd
6 from utils.registery import DATASET_REGISTRY
7
8
9 @DATASET_REGISTRY.register()
10 class CoordinatesData(Dataset):
11
12 def __init__(self,
13 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
14 anno_file: str = 'train.csv',
15 phase: str = 'train'):
16 self.data_root = data_root
17 self.df = pd.read_csv(anno_file)
18 self.phase = phase
19
20
21 def __len__(self):
22 return len(self.df)
23
24 def __getitem__(self, idx):
25 series = self.df.iloc[idx]
26 name = series['name']
27
28 with open(os.path.join(self.data_root, self.phase, name), 'r') as fp:
29 input_coordinates_list, label_list = json.load(fp)
30
31 input_coordinates = torch.tensor(input_coordinates_list)
32 label = torch.tensor(label_list).float()
33
34 return input_coordinates, label
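A minimal sketch of one sample file and how this dataset reads it. The values below are made up; the `[coordinates, label]` layout and the 29×8 / 5-element shapes follow the dataset-builder script and `MLPModel` elsewhere in this commit:

```python
import json
import torch

# hypothetical sample: 29 boxes of 8 normalized corner coordinates,
# plus a 5-element label (one-hot for a labelled field, all zeros for background)
sample = [[[0.10, 0.20, 0.30, 0.20, 0.30, 0.25, 0.10, 0.25]] * 29, [0, 1, 0, 0, 0]]
with open('/tmp/example.json', 'w') as fp:
    json.dump(sample, fp)

with open('/tmp/example.json') as fp:
    coords_list, label_list = json.load(fp)

coords = torch.tensor(coords_list)         # shape (29, 8)
label = torch.tensor(label_list).float()   # shape (5,)
print(coords.shape, label.shape)
```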
1 from .builder import build_dataloader
2
3 __all__ = ['build_dataloader']
1 import copy
2 from torch.utils.data import DataLoader
3 from utils.registery import DATASET_REGISTRY
4
5 from .CoordinatesData import CoordinatesData
6
7
8 def build_dataset(cfg):
9
10 dataset_cfg = copy.deepcopy(cfg)
11 try:
12 dataset_cfg = dataset_cfg['dataset']
13 except Exception:
14 raise KeyError('config should contain a `dataset` section!')
15
16 train_cfg = copy.deepcopy(dataset_cfg)
17 val_cfg = copy.deepcopy(dataset_cfg)
18 train_cfg['args']['anno_file'] = train_cfg['args'].pop('train_anno_file')
19 train_cfg['args'].pop('val_anno_file', None)
20 train_cfg['args']['phase'] = 'train'
21 val_cfg['args']['anno_file'] = val_cfg['args'].pop('val_anno_file')
22 val_cfg['args'].pop('train_anno_file', None)
23 val_cfg['args']['phase'] = 'valid'
24
25 train_data = DATASET_REGISTRY.get(cfg['dataset']['name'])(**train_cfg['args'])
26 val_data = DATASET_REGISTRY.get(cfg['dataset']['name'])(**val_cfg['args'])
27
28 return train_data, val_data
29
30
31
32 def build_dataloader(cfg):
33
34 dataloader_cfg = copy.deepcopy(cfg)
35 try:
36 dataloader_cfg = cfg['dataloader']
37 except Exception:
38 raise KeyError('config should contain a `dataloader` section!')
39
40 train_ds, val_ds = build_dataset(cfg)
41
42 train_loader = DataLoader(train_ds,
43 **dataloader_cfg)
44
45 val_loader = DataLoader(val_ds,
46 **dataloader_cfg)
47
48 return train_loader, val_loader
49
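A quick usage sketch, assuming the dataset and annotation files referenced by `config/mlp.yaml` exist:

```python
import yaml
from data import build_dataloader

with open('config/mlp.yaml') as fp:
    cfg = yaml.load(fp, Loader=yaml.FullLoader)

train_loader, val_loader = build_dataloader(cfg)
coords, labels = next(iter(train_loader))
print(coords.shape, labels.shape)   # with the config above: (32, 29, 8) and (32, 5)
```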
1 import os
2 import cv2
3 import uuid
4 import json
5 import random
6 import copy
7
8 import pandas as pd
9 from tools import get_file_paths, load_json
10
11
12 def text_statistics(go_res_dir):
13 """
14 Args:
15 go_res_dir: str directory containing the general-OCR result JSON files
16 Returns: list the most frequent texts with their occurrence counts
17 """
18 json_count = 0
19 text_dict = {}
20 go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
21 for go_res_json_path in go_res_json_paths:
22 print('Info: start {0}'.format(go_res_json_path))
23 json_count += 1
24 go_res = load_json(go_res_json_path)
25 for _, text in go_res.values():
26 if text in text_dict:
27 text_dict[text] += 1
28 else:
29 text_dict[text] = 1
30 top_text_list = []
31 # sort texts by occurrence count, descending
32 for text, count in sorted(text_dict.items(), key=lambda x: x[1], reverse=True):
33 if text == '':
34 continue
35 # stop once a text appears in no more than 1/3 of the JSON files
36 if count <= json_count // 3:
37 break
38 top_text_list.append((text, count))
39 return top_text_list
40
41 def build_anno_file(dataset_dir, anno_file_path):
42 img_list = os.listdir(dataset_dir)
43 random.shuffle(img_list)
44 df = pd.DataFrame(columns=['name'])
45 df['name'] = img_list
46 df.to_csv(anno_file_path)
47
48 def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
49 """
50 Args:
51 img_dir: str image directory
52 go_res_dir: str directory of the general-OCR result JSONs
53 label_dir: str directory of the annotation JSONs
54 top_text_list: list the most frequent texts with their counts
55 skip_list: list image file names to skip
56 save_dir: str output directory for the generated dataset
57 """
58 # if os.path.exists(save_dir):
59 # return
60 # else:
61 # os.makedirs(save_dir, exist_ok=True)
62
63 count = 0
64 un_count = 0
65 top_text_count = len(top_text_list)
66 for img_name in sorted(os.listdir(img_dir)):
67 if img_name in skip_list:
68 print('Info: skip {0}'.format(img_name))
69 continue
70
71 print('Info: start {0}'.format(img_name))
72 image_path = os.path.join(img_dir, img_name)
73 img = cv2.imread(image_path)
74 h, w, _ = img.shape
75 base_image_name, _ = os.path.splitext(img_name)
76 go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
77 go_res = load_json(go_res_json_path)
78
79 input_key_list = []
80 not_found_count = 0
81 go_key_set = set()
82 for top_text, _ in top_text_list:
83 for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), text) in go_res.items():
84 if text == top_text:
85 input_key_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
86 go_key_set.add(go_key)
87 break
88 else:
89 not_found_count += 1
90 input_key_list.append([0, 0, 0, 0, 0, 0, 0, 0])
91 if not_found_count >= top_text_count // 3:
92 print('Info: skip {0} : {1}/{2}'.format(img_name, not_found_count, top_text_count))
93 continue
94
95 label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
96 label_res = load_json(label_json_path)
97
98 # invoice date, invoice code, machine-printed number, vehicle type, phone
99 test_group_id = [1, 2, 5, 9, 20]
100 group_list = []
101 for group_id in test_group_id:
102 for item in label_res.get("shapes", []):
103 if item.get("group_id") == group_id:
104 x_list = []
105 y_list = []
106 for point in item['points']:
107 x_list.append(point[0])
108 y_list.append(point[1])
109 group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2])
110 break
111 else:
112 group_list.append(None)
113
114 go_center_list = []
115 for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), _) in go_res.items():
116 if go_key in go_key_set:
117 continue
118 xmin = min(x0, x1, x2, x3)
119 ymin = min(y0, y1, y2, y3)
120 xmax = max(x0, x1, x2, x3)
121 ymax = max(y0, y1, y2, y3)
122 xcenter = xmin + (xmax - xmin)/2
123 ycenter = ymin + (ymax - ymin)/2
124 go_center_list.append([xcenter, ycenter, go_key])
125
126 group_go_key_list = []
127 for label_center_list in group_list:
128 if isinstance(label_center_list, list):
129 min_go_key = None
130 min_length = None
131 for go_x_center, go_y_center, go_key in go_center_list:
132 if go_key in go_key_set:
133 continue
134 length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1])
135 if min_go_key is None or length < min_length:
136 min_go_key = go_key
137 min_length = length
138 if min_go_key is not None:
139 go_key_set.add(min_go_key)
140 group_go_key_list.append(min_go_key)
141 else:
142 group_go_key_list.append(None)
143 else:
144 group_go_key_list.append(None)
145
146 src_label_list = [0 for _ in test_group_id]
147 for idx, find_go_key in enumerate(group_go_key_list):
148 if find_go_key is None:
149 continue
150 (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res[find_go_key]
151 input_list = copy.deepcopy(input_key_list)
152 input_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
153
154 input_label = copy.deepcopy(src_label_list)
155 input_label[idx] = 1
156 # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, find_go_key)))), 'w') as fp:
157 # json.dump([input_list, input_label], fp)
158 count += 1
159
160 for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), _) in go_res.items():
161 if go_key in go_key_set:
162 continue
163 input_list = copy.deepcopy(input_key_list)
164 input_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
165 # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, go_key)))), 'w') as fp:
166 # json.dump([input_list, src_label_list], fp)
167 un_count += 1
168
169 # break
170 print(count)
171 print(un_count)
172
173
174 if __name__ == '__main__':
175 base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
176 go_dir = os.path.join(base_dir, 'go_res')
177 dataset_save_dir = os.path.join(base_dir, 'dataset')
178 label_dir = os.path.join(base_dir, 'labeled')
179
180 train_go_path = os.path.join(go_dir, 'train')
181 train_image_path = os.path.join(label_dir, 'train', 'image')
182 train_label_path = os.path.join(label_dir, 'train', 'label')
183 train_dataset_dir = os.path.join(dataset_save_dir, 'train')
184 train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')
185
186 valid_go_path = os.path.join(go_dir, 'valid')
187 valid_image_path = os.path.join(label_dir, 'valid', 'image')
188 valid_label_path = os.path.join(label_dir, 'valid', 'label')
189 valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
190 valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')
191
192 # top_text_list = text_statistics(go_dir)
193 # for t in top_text_list:
194 # print(t)
195
196 filter_from_top_text_list = [
197 ('机器编号', 496),
198 ('购买方名称', 496),
199 ('合格证号', 495),
200 ('进口证明书号', 495),
201 ('机打代码', 494),
202 ('车辆类型', 492),
203 ('完税凭证号码', 492),
204 ('机打号码', 491),
205 ('发动机号码', 491),
206 ('主管税务', 491),
207 ('价税合计', 489),
208 ('机关及代码', 489),
209 ('销货单位名称', 486),
210 ('厂牌型号', 485),
211 ('产地', 485),
212 ('商检单号', 483),
213 ('电话', 476),
214 ('开户银行', 472),
215 ('车辆识别代号/车架号码', 463),
216 ('身份证号码', 454),
217 ('吨位', 452),
218 ('备注:一车一票', 439),
219 ('地', 432),
220 ('账号', 431),
221 ('统一社会信用代码/', 424),
222 ('限乘人数', 404),
223 ('税额', 465),
224 ('址', 392)
225 ]
226
227 skip_list_train = [
228 'CH-B101910792-page-12.jpg',
229 'CH-B101655312-page-13.jpg',
230 'CH-B102278656.jpg',
231 'CH-B101846620_page_1_img_0.jpg',
232 'CH-B103062528-0.jpg',
233 'CH-B102613120-3.jpg',
234 'CH-B102997980-3.jpg',
235 'CH-B102680060-3.jpg',
236 # 'CH-B102995500-2.jpg', # no labelled value
237 ]
238
239 skip_list_valid = [
240 'CH-B102897920-2.jpg',
241 'CH-B102551284-0.jpg',
242 'CH-B102879376-2.jpg',
243 'CH-B101509488-page-16.jpg',
244 'CH-B102708352-2.jpg',
245 ]
246
247 # build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
248
249 build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)
250
251 # build_anno_file(train_dataset_dir, train_anno_file_path)
252 # build_anno_file(valid_dataset_dir, valid_anno_file_path)
253
254
1 import json
2 import os
3
4
5 def get_exclude_paths(input_path, exclude_list=[]):
6 """
7 Args:
8 input_path: str target directory
9 exclude_list: list files or directories to exclude, relative to input_path
10 Returns: set absolute paths of the excluded files or directories
11 """
12 exclude_paths_set = set()
13 if os.path.isdir(input_path):
14 for path in exclude_list:
15 abs_path = path if os.path.isabs(path) else os.path.join(input_path, path)
16 if not os.path.exists(abs_path):
17 print('Warning: exclude path not exists: {0}'.format(abs_path))
18 continue
19 exclude_paths_set.add(abs_path)
20 return exclude_paths_set
21
22 def get_file_paths(input_path, suffix_list, exclude_list=[]):
23 """
24 Args:
25 input_path: str target directory
26 suffix_list: list file suffixes to search for
27 exclude_list: list files or directories to exclude, relative to input_path
28 Returns: generator yielding the absolute paths of the matching files
29 """
30 exclude_paths_set = get_exclude_paths(input_path, exclude_list)
31 for parent, _, filenames in os.walk(input_path):
32 if parent in exclude_paths_set:
33 print('Info: exclude path: {0}'.format(parent))
34 continue
35 for filename in filenames:
36 for suffix in suffix_list:
37 if filename.endswith(suffix):
38 file_path = os.path.join(parent, filename)
39 break
40 else:
41 continue
42 if file_path in exclude_paths_set:
43 print('Info: exclude path: {0}'.format(file_path))
44 continue
45 yield file_path
46
47 def load_json(json_path):
48 """
49 Args:
50 json_path: str path to the JSON file
51 Returns: obj the parsed JSON object
52 """
53 with open(json_path, 'r') as fp:
54 output = json.load(fp)
55 return output
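A small usage sketch for the helpers above; the directory is the default from the dataset-builder script and `'exclude_me'` is a hypothetical entry:

```python
from tools import get_file_paths, load_json

# iterate all OCR result JSONs under go_res, skipping one sub-directory
go_res_dir = '/Users/zhouweiqi/Downloads/gcfp/data/go_res'
for json_path in get_file_paths(go_res_dir, ['.json'], exclude_list=['exclude_me']):
    go_res = load_json(json_path)
    # each entry maps an OCR box key to ((x0, y0, x1, y1, x2, y2, x3, y3), text)
    for (x0, y0, x1, y1, x2, y2, x3, y3), text in go_res.values():
        print(text)
```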
1 from .builder import build_loss
2
3 __all__ = ['build_loss']
1 import copy
2 import torch
3 import inspect
4 from utils.registery import LOSS_REGISTRY
5 from torchvision.ops import sigmoid_focal_loss
6
7 class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):
8
9 def __init__(self,
10 weight= None,
11 size_average=None,
12 reduce=None,
13 reduction: str = 'mean',
14 alpha: float = 0.25,
15 gamma: float = 2):
16 super().__init__(weight, size_average, reduce, reduction)
17 self.alpha = alpha
18 self.gamma = gamma
19 self.reduction = reduction
20
21 def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
22 return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)
23
24
25 def register_sigmoid_focal_loss():
26 LOSS_REGISTRY.register()(SigmoidFocalLoss)
27
28
29 def register_torch_loss():
30 for module_name in dir(torch.nn):
31 if module_name.startswith('__') or 'Loss' not in module_name:
32 continue
33 _loss = getattr(torch.nn, module_name)
34 if inspect.isclass(_loss) and issubclass(_loss, torch.nn.Module):
35 LOSS_REGISTRY.register()(_loss)
36
37 def build_loss(cfg):
38 register_sigmoid_focal_loss()
39 register_torch_loss()
40 loss_cfg = copy.deepcopy(cfg)
41 try:
42 loss_cfg = cfg['solver']['loss']
43 except Exception:
44 raise KeyError('config should contain a `solver.loss` section!')
45
46 # return sigmoid_focal_loss
47 return LOSS_REGISTRY.get(loss_cfg['name'])(**loss_cfg['args'])
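A small sanity-check sketch for the loss wrapper (dummy logits and multi-hot targets; the module path is assumed to be `loss/builder.py`). Note that `torchvision.ops.sigmoid_focal_loss` applies a sigmoid to its inputs internally, while `MLPModel` below already ends in `nn.Sigmoid`, so with this pairing the model output appears to go through a sigmoid twice.

```python
import torch
from loss.builder import SigmoidFocalLoss  # assumed module path

loss_fn = SigmoidFocalLoss(reduction='mean', alpha=0.25, gamma=2)
logits = torch.randn(4, 5)       # raw scores for 4 samples x 5 classes
targets = torch.zeros(4, 5)
targets[0, 1] = 1.0              # one positive field, the rest background
print(loss_fn(logits, targets))  # scalar tensor
```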
1 import argparse
2 import torch
3 import yaml
4 from solver.builder import build_solver
5
6
7 def main():
8 parser = argparse.ArgumentParser()
9 parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file')
10 args = parser.parse_args()
11
12 cfg = yaml.load(open(args.config, 'r').read(), Loader=yaml.FullLoader)
13 # print(cfg)
14 # print(torch.cuda.is_available())
15
16 solver = build_solver(cfg)
17 solver.run()
18
19
20 if __name__ == '__main__':
21 main()
1 from .builder import build_model
2
3 __all__ = ['build_model']
1 import copy
2 from utils import MODEL_REGISTRY
3
4 from .mlp import MLPModel
5
6
7 def build_model(cfg):
8 model_cfg = copy.deepcopy(cfg)
9 try:
10 model_cfg = model_cfg['model']
11 except Exception:
12 raise KeyError('config should contain a `model` section!')
13
14 model = MODEL_REGISTRY.get(model_cfg['name'])(**model_cfg['args'])
15
16 return model
17
1 from abc import ABCMeta
2
3 import torch.nn as nn
4 from utils.registery import MODEL_REGISTRY
5
6
7 @MODEL_REGISTRY.register()
8 class MLPModel(nn.Module):
9
10 def __init__(self, activation):
11 super().__init__()
12 self.activation_fn = activation  # 'relu' from the config; currently unused, the layers below hardcode nn.ReLU
13 self.flatten = nn.Flatten()
14 self.linear_relu_stack = nn.Sequential(
15 nn.Linear(29*8, 512),
16 nn.ReLU(),
17 nn.Linear(512, 256),
18 nn.ReLU(),
19 nn.Linear(256, 5),
20 nn.Sigmoid(),
21 )
22 self._initialize_weights()
23
24 def _initialize_weights(self):
25 for m in self.modules():
26 if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
27 nn.init.xavier_uniform_(m.weight)
28 if m.bias is not None:
29 nn.init.constant_(m.bias, 0)
30 elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
31 nn.init.constant_(m.weight, 1)
32 nn.init.constant_(m.bias, 0)
33 elif isinstance(m, nn.Linear):
34 nn.init.xavier_uniform_(m.weight)
35 if m.bias is not None:
36 nn.init.constant_(m.bias, 0)
37
38 def forward(self, x):
39 x = self.flatten(x)
40 logits = self.linear_relu_stack(x)
41 return logits
42
43
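A dummy forward pass for shape orientation. The 29×8 input appears to correspond to the 28 anchor texts in `filter_from_top_text_list` plus one candidate box, each as 8 normalized corner coordinates, and the 5 outputs to the field classes in `test_group_id`:

```python
import torch
from model import build_model

model = build_model({'model': {'name': 'MLPModel', 'args': {'activation': 'relu'}}})
x = torch.rand(32, 29, 8)   # fake batch: 32 samples, 29 boxes, 8 coords each
probs = model(x)            # Flatten -> 232 -> 512 -> 256 -> 5, ending in Sigmoid
print(probs.shape)          # torch.Size([32, 5]), values in (0, 1)
```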
1 from .builder import build_optimizer, build_lr_scheduler
2
3 __all__ = ['build_optimizer', 'build_lr_scheduler']
1 import torch
2 import inspect
3 from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
4 import copy
5
6 def register_torch_optimizers():
7 """
8 Register all optimizers implemented by torch
9 """
10 for module_name in dir(torch.optim):
11 if module_name.startswith('__'):
12 continue
13 _optim = getattr(torch.optim, module_name)
14 if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer):
15 OPTIMIZER_REGISTRY.register()(_optim)
16
17 def build_optimizer(cfg):
18 register_torch_optimizers()
19 optimizer_cfg = copy.deepcopy(cfg)
20 try:
21 optimizer_cfg = cfg['solver']['optimizer']
22 except Exception:
23 raise KeyError('config should contain a `solver.optimizer` section!')
24
25 return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])
26
27 def register_torch_lr_scheduler():
28 """
29 Register all lr_schedulers implemented by torch
30 """
31 for module_name in dir(torch.optim.lr_scheduler):
32 if module_name.startswith('__'):
33 continue
34
35 _scheduler = getattr(torch.optim.lr_scheduler, module_name)
36 if inspect.isclass(_scheduler) and issubclass(_scheduler, torch.optim.lr_scheduler._LRScheduler):
37 LR_SCHEDULER_REGISTRY.register()(_scheduler)
38
39 def build_lr_scheduler(cfg):
40 register_torch_lr_scheduler()
41 scheduler_cfg = copy.deepcopy(cfg)
42 try:
43 scheduler_cfg = cfg['solver']['lr_scheduler']
44 except Exception:
45 raise KeyError('config should contain a `solver.lr_scheduler` section!')
46 return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])
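Unlike `build_model` and `build_loss`, these two builders return the registered class rather than an instance; the solver instantiates it with the model parameters and the `args` from the config. A minimal sketch with the values from `config/mlp.yaml`:

```python
import torch
from optimizer import build_optimizer

cfg = {'solver': {'optimizer': {'name': 'Adam', 'args': {'lr': 1e-4, 'weight_decay': 5e-5}}}}
model = torch.nn.Linear(232, 5)   # stand-in model for illustration

optimizer_cls = build_optimizer(cfg)   # -> torch.optim.Adam (the class itself)
optimizer = optimizer_cls(model.parameters(), **cfg['solver']['optimizer']['args'])
```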
1 torch==1.13.0
2 torchvision==0.14.0
3 PyYaml==6.0
4 loguru==0.6.0
5 pandas==1.5.2
6 opencv-python==4.6.0.66
1 from .builder import build_solver
2
3 __all__ = ['build_solver']
1 import copy
2
3 from utils.registery import SOLVER_REGISTRY
4 from .mlp_solver import MLPSolver
5
6
7 def build_solver(cfg):
8 cfg = copy.deepcopy(cfg)
9
10 try:
11 solver_cfg = cfg['solver']
12 except Exception:
13 raise KeyError('config should contain a `solver` section!')
14
15 return SOLVER_REGISTRY.get(solver_cfg['name'])(cfg)
1 import os
2 import copy
3 import torch
4
5 from model import build_model
6 from data import build_dataloader
7 from optimizer import build_optimizer, build_lr_scheduler
8 from loss import build_loss
9 from utils import SOLVER_REGISTRY, get_logger_and_log_dir
10
11
12 @SOLVER_REGISTRY.register()
13 class MLPSolver(object):
14
15 def __init__(self, cfg):
16 self.cfg = copy.deepcopy(cfg)
17
18 self.train_loader, self.val_loader = build_dataloader(cfg)
19 self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
20 self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)
21
22 # BatchNorm ?
23 self.model = build_model(cfg)
24
25 self.loss_fn = build_loss(cfg)
26
27 self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])
28
29 self.hyper_params = cfg['solver']['args']
30 try:
31 self.epoch = self.hyper_params['epoch']
32 except Exception:
33 raise KeyError('config should contain `epoch` in `solver.args`!')
34
35 self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])
36
37 @staticmethod
38 def evaluate(y_pred, y_true, thresholds=0.5):
39 y_pred_idx = torch.argmax(y_pred, dim=1) + 1  # 1-based index of the top-scoring class
40 y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()  # 0 when no score clears the threshold ("other")
41 y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)
42
43 y_true_idx = torch.argmax(y_true, dim=1) + 1
44 y_true_is_other = torch.sum(y_true, dim=1)  # 0 for all-zero ("other") label rows
45 y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)
46
47 return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()  # number of correct samples in the batch
48
49 def train_loop(self):
50 self.model.train()
51
52 train_loss = 0
53 for batch, (X, y) in enumerate(self.train_loader):
54 pred = self.model(X)
55
56 # loss = self.loss_fn(pred, y, reduction="mean")
57 loss = self.loss_fn(pred, y)
58 train_loss += loss.item()
59
60 if batch % 100 == 0:
61 loss_value, current = loss.item(), batch
62 self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')
63
64 self.optimizer.zero_grad()
65 loss.backward()
66 self.optimizer.step()
67
68 train_loss /= self.train_loader_size
69 self.logger.info(f'train mean loss: {train_loss :.4f}')
70
71 @torch.no_grad()
72 def val_loop(self, t):
73 self.model.eval()
74
75 val_loss, correct = 0, 0
76 for X, y in self.val_loader:
77 pred = self.model(X)
78
79 correct += self.evaluate(pred, y)
80
81 loss = self.loss_fn(pred, y)
82 val_loss += loss.item()
83
84 correct /= self.val_dataset_size
85 val_loss /= self.val_loader_size
86
87 self.logger.info(f"val accuracy: {correct :.4f}, val loss: {val_loss :.4f}")
88
89 def save_checkpoint(self, epoch_id):
90 self.model.eval()
91 torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))
92
93 def run(self):
94 self.logger.info('==> Start Training')
95 print(self.model)
96
97 # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])
98
99 for t in range(self.epoch):
100 self.logger.info(f'==> epoch {t + 1}')
101
102 self.train_loop()
103 self.val_loop(t + 1)
104 self.save_checkpoint(t + 1)
105
106 # lr_scheduler.step()
107
108 self.logger.info('==> End Training')
109
110 # for X, y in self.train_loader:
111 # print(X.size())
112 # print(y.size())
113
114 # pred = self.model(X)
115 # print(pred)
116 # print(y)
117
118 # loss = self.loss_fn(pred, y, reduction="mean")
119 # print(loss)
120
121 # break
122
123 # y_true = [
124 # [0, 1, 0],
125 # [0, 1, 0],
126 # [0, 0, 1],
127 # [0, 0, 0],
128 # ]
129 # y_pred = [
130 # [0.1, 0.8, 0.9],
131 # [0.2, 0.8, 0.1],
132 # [0.2, 0.1, 0.85],
133 # [0.2, 0.6, 0.1],
134 # ]
135 # acc_num = self.evaluate(torch.tensor(y_pred), torch.tensor(y_true))
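Working the commented example above through the `evaluate` logic by hand: the predictions rebuild to `[3, 2, 3, 2]` (argmax + 1, kept because every row's max score clears 0.5), the targets rebuild to `[2, 2, 3, 0]` (the all-zero row collapses to 0), so two of the four samples match. A short check of that arithmetic:

```python
import torch

y_true = torch.tensor([[0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]]).float()
y_pred = torch.tensor([[0.1, 0.8, 0.9], [0.2, 0.8, 0.1], [0.2, 0.1, 0.85], [0.2, 0.6, 0.1]])

pred_rebuild = (torch.argmax(y_pred, dim=1) + 1) * (torch.amax(y_pred, dim=1) > 0.5).int()  # [3, 2, 3, 2]
true_rebuild = (torch.argmax(y_true, dim=1) + 1) * torch.sum(y_true, dim=1)                 # [2., 2., 3., 0.]
print(torch.sum((pred_rebuild == true_rebuild).int()).item())                               # 2
```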
1 from .registery import *
2 from .logger import get_logger_and_log_dir
3
4 __all__ = [
5 'Registry',
6 'get_logger_and_log_dir',
7 ]
8
1 import loguru
2 import os
3 import datetime
4
5 def get_logger_and_log_dir(log_root, suffix):
6 """
7 get logger and log path
8
9 Args:
10 log_root (str): root path of log
11 suffix (str): log save name
12
13 Returns:
14 logger (loguru.logger): logger object
15 log_path (str): current root log path (with suffix)
16 """
17 crt_date = datetime.date.today().strftime('%Y-%m-%d')
18 log_dir = os.path.join(log_root, crt_date, suffix)
19 if not os.path.exists(log_dir):
20 os.makedirs(log_dir)
21
22 logger_path = os.path.join(log_dir, 'logfile.log')
23 fmt = '{time:YYYY-MM-DD at HH:mm:ss} | {message}'
24 logger = loguru.logger
25 logger.add(logger_path, format=fmt)
26
27 return logger, log_dir
1 class Registry():
2 """
3 The registry that provides name -> object mapping, to support third-party
4 users' custom modules.
5
6 """
7
8 def __init__(self, name):
9 """
10 Args:
11 name (str): the name of this registry
12 """
13 self._name = name
14 self._obj_map = {}
15
16 def _do_register(self, name, obj, suffix=None):
17 if isinstance(suffix, str):
18 name = name + '_' + suffix
19
20 assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
21 f"in '{self._name}' registry!")
22 self._obj_map[name] = obj
23
24 def register(self, obj=None, suffix=None):
25 """
26 Register the given object under the name `obj.__name__`.
27 Can be used as either a decorator or not.
28 See docstring of this class for usage.
29 """
30 if obj is None:
31 # used as a decorator
32 def deco(func_or_class):
33 name = func_or_class.__name__
34 self._do_register(name, func_or_class, suffix)
35 return func_or_class
36
37 return deco
38
39 # used as a function call
40 name = obj.__name__
41 self._do_register(name, obj, suffix)
42
43 def get(self, name, suffix='soulwalker'):
44 ret = self._obj_map.get(name)
45 if ret is None:
46 ret = self._obj_map.get(name + '_' + suffix)
47 print(f'Name {name} not found, falling back to {name}_{suffix}!')
48 if ret is None:
49 raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
50 return ret
51
52 def __contains__(self, name):
53 return name in self._obj_map
54
55 def __iter__(self):
56 return iter(self._obj_map.items())
57
58 def keys(self):
59 return self._obj_map.keys()
60
61
62 DATASET_REGISTRY = Registry('dataset')
63 MODEL_REGISTRY = Registry('model')
64 LOSS_REGISTRY = Registry('loss')
65 METRIC_REGISTRY = Registry('metric')
66 OPTIMIZER_REGISTRY = Registry('optimizer')
67 SOLVER_REGISTRY = Registry('solver')
68 LR_SCHEDULER_REGISTRY = Registry('lr_scheduler')
69 COLLATE_FN_REGISTRY = Registry('collate_fn')
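A small usage sketch for the registry; `MyDataset` and `my_collate` are illustrative names, and the decorator form is the same pattern used by `CoordinatesData`, `MLPModel` and `MLPSolver` above:

```python
# decorator style
@DATASET_REGISTRY.register()
class MyDataset:
    pass

# plain function-call style
def my_collate(batch):
    return batch

COLLATE_FN_REGISTRY.register(my_collate)

assert 'MyDataset' in DATASET_REGISTRY
dataset_cls = DATASET_REGISTRY.get('MyDataset')   # raises KeyError for unknown names
```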