first commit
0 parents
Showing 23 changed files with 901 additions and 0 deletions
.gitignore
0 → 100644
README.md
0 → 100644
config/mlp.yaml
0 → 100644
seed: 3407

dataset:
  name: 'CoordinatesData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'

dataloader:
  batch_size: 32
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'MLPModel'
  args:
    activation: 'relu'

solver:
  name: 'MLPSolver'
  args:
    epoch: 100

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1

  loss:
    name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"

  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'mlp'
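For reference, a minimal sketch of how this config is consumed; the nesting of optimizer, lr_scheduler, loss and logger under solver mirrors how the builders below index into the dict (cfg['solver']['optimizer'], cfg['solver']['loss'], and so on):

import yaml

# Load the experiment config the same way main.py does.
with open('./config/mlp.yaml', 'r') as fp:
    cfg = yaml.load(fp.read(), Loader=yaml.FullLoader)

print(cfg['dataloader']['batch_size'])           # 32, read by data/builder.py
print(cfg['solver']['optimizer']['args']['lr'])  # 0.0001, read by optimizer/builder.py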
data/CoordinatesData.py
0 → 100644
import os
import json
import torch
from torch.utils.data import Dataset
import pandas as pd
from utils.registery import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class CoordinatesData(Dataset):

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        series = self.df.iloc[idx]
        name = series['name']

        with open(os.path.join(self.data_root, self.phase, name), 'r') as fp:
            input_coordinates_list, label_list = json.load(fp)

        input_coordinates = torch.tensor(input_coordinates_list)
        label = torch.tensor(label_list).float()

        return input_coordinates, label
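A quick usage sketch of the dataset (paths come from config/mlp.yaml; each sample JSON produced by data/create_dataset.py is expected to hold a 29x8 list of normalized box coordinates plus a 5-element label):

from data.CoordinatesData import CoordinatesData

ds = CoordinatesData(data_root='/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                     anno_file='/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv',
                     phase='train')
coords, label = ds[0]
print(len(ds), coords.shape, label.shape)  # e.g. N torch.Size([29, 8]) torch.Size([5])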
data/__init__.py
0 → 100644
data/builder.py
0 → 100644
import copy
from torch.utils.data import DataLoader
from utils.registery import DATASET_REGISTRY

from .CoordinatesData import CoordinatesData


def build_dataset(cfg):

    dataset_cfg = copy.deepcopy(cfg)
    try:
        dataset_cfg = dataset_cfg['dataset']
    except Exception:
        raise KeyError('config should contain a {dataset} section!')

    train_cfg = copy.deepcopy(dataset_cfg)
    val_cfg = copy.deepcopy(dataset_cfg)
    train_cfg['args']['anno_file'] = train_cfg['args'].pop('train_anno_file')
    train_cfg['args'].pop('val_anno_file', None)
    train_cfg['args']['phase'] = 'train'
    val_cfg['args']['anno_file'] = val_cfg['args'].pop('val_anno_file')
    val_cfg['args'].pop('train_anno_file', None)
    val_cfg['args']['phase'] = 'valid'

    train_data = DATASET_REGISTRY.get(cfg['dataset']['name'])(**train_cfg['args'])
    val_data = DATASET_REGISTRY.get(cfg['dataset']['name'])(**val_cfg['args'])

    return train_data, val_data


def build_dataloader(cfg):

    dataloader_cfg = copy.deepcopy(cfg)
    try:
        dataloader_cfg = dataloader_cfg['dataloader']
    except Exception:
        raise KeyError('config should contain a {dataloader} section!')

    train_ds, val_ds = build_dataset(cfg)

    train_loader = DataLoader(train_ds, **dataloader_cfg)
    val_loader = DataLoader(val_ds, **dataloader_cfg)

    return train_loader, val_loader
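A minimal sketch of building the loaders (assuming the CSV annotation files and sample JSONs produced by data/create_dataset.py are in place):

import yaml
from data.builder import build_dataloader

with open('./config/mlp.yaml', 'r') as fp:
    cfg = yaml.load(fp.read(), Loader=yaml.FullLoader)

train_loader, val_loader = build_dataloader(cfg)
X, y = next(iter(train_loader))
print(X.shape, y.shape)  # expected: torch.Size([32, 29, 8]) torch.Size([32, 5])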
data/create_dataset.py
0 → 100644
import os
import cv2
import uuid
import json
import random
import copy

import pandas as pd
from tools import get_file_paths, load_json


def text_statistics(go_res_dir):
    """
    Args:
        go_res_dir: str  directory holding the general-OCR JSON results
    Returns: list  the most frequent texts and their occurrence counts
    """
    json_count = 0
    text_dict = {}
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
        json_count += 1
        go_res = load_json(go_res_json_path)
        for _, text in go_res.values():
            if text in text_dict:
                text_dict[text] += 1
            else:
                text_dict[text] = 1
    top_text_list = []
    # sort texts by occurrence count (descending)
    for text, count in sorted(text_dict.items(), key=lambda x: x[1], reverse=True):
        if text == '':
            continue
        # drop texts that appear in no more than a third of the JSON files
        if count <= json_count // 3:
            break
        top_text_list.append((text, count))
    return top_text_list

def build_anno_file(dataset_dir, anno_file_path):
    img_list = os.listdir(dataset_dir)
    random.shuffle(img_list)
    df = pd.DataFrame(columns=['name'])
    df['name'] = img_list
    df.to_csv(anno_file_path)

def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
    """
    Args:
        img_dir: str  image directory
        go_res_dir: str  directory holding the general-OCR JSON results
        label_dir: str  directory holding the annotation JSON files
        top_text_list: list  the most frequent texts and their occurrence counts
        skip_list: list  images to skip
        save_dir: str  output directory for the dataset
    """
    # if os.path.exists(save_dir):
    #     return
    # else:
    #     os.makedirs(save_dir, exist_ok=True)

    count = 0
    un_count = 0
    top_text_count = len(top_text_list)
    for img_name in sorted(os.listdir(img_dir)):
        if img_name in skip_list:
            print('Info: skip {0}'.format(img_name))
            continue

        print('Info: start {0}'.format(img_name))
        image_path = os.path.join(img_dir, img_name)
        img = cv2.imread(image_path)
        h, w, _ = img.shape
        base_image_name, _ = os.path.splitext(img_name)
        go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
        go_res = load_json(go_res_json_path)

        input_key_list = []
        not_found_count = 0
        go_key_set = set()
        for top_text, _ in top_text_list:
            for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), text) in go_res.items():
                if text == top_text:
                    input_key_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                    go_key_set.add(go_key)
                    break
            else:
                not_found_count += 1
                input_key_list.append([0, 0, 0, 0, 0, 0, 0, 0])
        if not_found_count >= top_text_count // 3:
            print('Info: skip {0} : {1}/{2}'.format(img_name, not_found_count, top_text_count))
            continue

        label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
        label_res = load_json(label_json_path)

        # target fields: invoice date, invoice code, machine-printed number, vehicle type, phone
        test_group_id = [1, 2, 5, 9, 20]
        group_list = []
        for group_id in test_group_id:
            for item in label_res.get("shapes", []):
                if item.get("group_id") == group_id:
                    x_list = []
                    y_list = []
                    for point in item['points']:
                        x_list.append(point[0])
                        y_list.append(point[1])
                    group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2])
                    break
            else:
                group_list.append(None)

        go_center_list = []
        for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), _) in go_res.items():
            if go_key in go_key_set:
                continue
            xmin = min(x0, x1, x2, x3)
            ymin = min(y0, y1, y2, y3)
            xmax = max(x0, x1, x2, x3)
            ymax = max(y0, y1, y2, y3)
            xcenter = xmin + (xmax - xmin)/2
            ycenter = ymin + (ymax - ymin)/2
            go_center_list.append([xcenter, ycenter, go_key])

        group_go_key_list = []
        for label_center_list in group_list:
            if isinstance(label_center_list, list):
                min_go_key = None
                min_length = None
                for go_x_center, go_y_center, go_key in go_center_list:
                    if go_key in go_key_set:
                        continue
                    length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1])
                    if min_go_key is None or length < min_length:
                        min_go_key = go_key
                        min_length = length
                if min_go_key is not None:
                    go_key_set.add(min_go_key)
                    group_go_key_list.append(min_go_key)
                else:
                    group_go_key_list.append(None)
            else:
                group_go_key_list.append(None)

        src_label_list = [0 for _ in test_group_id]
        for idx, find_go_key in enumerate(group_go_key_list):
            if find_go_key is None:
                continue
            (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res[find_go_key]
            input_list = copy.deepcopy(input_key_list)
            input_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])

            input_label = copy.deepcopy(src_label_list)
            input_label[idx] = 1
            # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, find_go_key)))), 'w') as fp:
            #     json.dump([input_list, input_label], fp)
            count += 1

        for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), _) in go_res.items():
            if go_key in go_key_set:
                continue
            input_list = copy.deepcopy(input_key_list)
            input_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
            # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, go_key)))), 'w') as fp:
            #     json.dump([input_list, src_label_list], fp)
            un_count += 1

        # break
    print(count)
    print(un_count)


if __name__ == '__main__':
    base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
    go_dir = os.path.join(base_dir, 'go_res')
    dataset_save_dir = os.path.join(base_dir, 'dataset')
    label_dir = os.path.join(base_dir, 'labeled')

    train_go_path = os.path.join(go_dir, 'train')
    train_image_path = os.path.join(label_dir, 'train', 'image')
    train_label_path = os.path.join(label_dir, 'train', 'label')
    train_dataset_dir = os.path.join(dataset_save_dir, 'train')
    train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')

    valid_go_path = os.path.join(go_dir, 'valid')
    valid_image_path = os.path.join(label_dir, 'valid', 'image')
    valid_label_path = os.path.join(label_dir, 'valid', 'label')
    valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
    valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')

    # top_text_list = text_statistics(go_dir)
    # for t in top_text_list:
    #     print(t)

    filter_from_top_text_list = [
        ('机器编号', 496),
        ('购买方名称', 496),
        ('合格证号', 495),
        ('进口证明书号', 495),
        ('机打代码', 494),
        ('车辆类型', 492),
        ('完税凭证号码', 492),
        ('机打号码', 491),
        ('发动机号码', 491),
        ('主管税务', 491),
        ('价税合计', 489),
        ('机关及代码', 489),
        ('销货单位名称', 486),
        ('厂牌型号', 485),
        ('产地', 485),
        ('商检单号', 483),
        ('电话', 476),
        ('开户银行', 472),
        ('车辆识别代号/车架号码', 463),
        ('身份证号码', 454),
        ('吨位', 452),
        ('备注:一车一票', 439),
        ('地', 432),
        ('账号', 431),
        ('统一社会信用代码/', 424),
        ('限乘人数', 404),
        ('税额', 465),
        ('址', 392)
    ]

    skip_list_train = [
        'CH-B101910792-page-12.jpg',
        'CH-B101655312-page-13.jpg',
        'CH-B102278656.jpg',
        'CH-B101846620_page_1_img_0.jpg',
        'CH-B103062528-0.jpg',
        'CH-B102613120-3.jpg',
        'CH-B102997980-3.jpg',
        'CH-B102680060-3.jpg',
        # 'CH-B102995500-2.jpg',  # no value annotated
    ]

    skip_list_valid = [
        'CH-B102897920-2.jpg',
        'CH-B102551284-0.jpg',
        'CH-B102879376-2.jpg',
        'CH-B101509488-page-16.jpg',
        'CH-B102708352-2.jpg',
    ]

    # build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)

    build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)

    # build_anno_file(train_dataset_dir, train_anno_file_path)
    # build_anno_file(valid_dataset_dir, valid_anno_file_path)
data/tools.py
0 → 100644
import json
import os


def get_exclude_paths(input_path, exclude_list=[]):
    """
    Args:
        input_path: str  target directory
        exclude_list: list  files or directories to exclude, relative to input_path
    Returns: set  absolute paths of the excluded files or directories
    """
    exclude_paths_set = set()
    if os.path.isdir(input_path):
        for path in exclude_list:
            abs_path = path if os.path.isabs(path) else os.path.join(input_path, path)
            if not os.path.exists(abs_path):
                print('Warning: exclude path not exists: {0}'.format(abs_path))
                continue
            exclude_paths_set.add(abs_path)
    return exclude_paths_set

def get_file_paths(input_path, suffix_list, exclude_list=[]):
    """
    Args:
        input_path: str  target directory
        suffix_list: list  file suffixes to search for
        exclude_list: list  files or directories to exclude, relative to input_path
    Yields: str  absolute paths of the matching files
    """
    exclude_paths_set = get_exclude_paths(input_path, exclude_list)
    for parent, _, filenames in os.walk(input_path):
        if parent in exclude_paths_set:
            print('Info: exclude path: {0}'.format(parent))
            continue
        for filename in filenames:
            for suffix in suffix_list:
                if filename.endswith(suffix):
                    file_path = os.path.join(parent, filename)
                    break
            else:
                continue
            if file_path in exclude_paths_set:
                print('Info: exclude path: {0}'.format(file_path))
                continue
            yield file_path

def load_json(json_path):
    """
    Args:
        json_path: str  path to a JSON file
    Returns: obj  the parsed JSON object
    """
    with open(json_path, 'r') as fp:
        output = json.load(fp)
    return output
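A small usage sketch of the generator above (the excluded subfolder 'tmp' is purely illustrative):

from tools import get_file_paths, load_json

for json_path in get_file_paths('/Users/zhouweiqi/Downloads/gcfp/data/go_res', ['.json'], exclude_list=['tmp']):
    go_res = load_json(json_path)
    print(json_path, len(go_res))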
loss/__init__.py
0 → 100644
loss/builder.py
0 → 100644
import copy
import torch
import inspect
from utils.registery import LOSS_REGISTRY
from torchvision.ops import sigmoid_focal_loss

class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):

    def __init__(self,
                 weight=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight, size_average, reduce, reduction)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # torchvision's sigmoid_focal_loss applies the sigmoid internally,
        # so `inputs` are expected to be raw logits.
        return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)


def register_sigmoid_focal_loss():
    LOSS_REGISTRY.register()(SigmoidFocalLoss)


def register_torch_loss():
    for module_name in dir(torch.nn):
        if module_name.startswith('__') or 'Loss' not in module_name:
            continue
        _loss = getattr(torch.nn, module_name)
        if inspect.isclass(_loss) and issubclass(_loss, torch.nn.Module):
            LOSS_REGISTRY.register()(_loss)

def build_loss(cfg):
    register_sigmoid_focal_loss()
    register_torch_loss()
    loss_cfg = copy.deepcopy(cfg)
    try:
        loss_cfg = loss_cfg['solver']['loss']
    except Exception:
        raise KeyError('config should contain a {solver.loss} section!')

    # return sigmoid_focal_loss
    return LOSS_REGISTRY.get(loss_cfg['name'])(**loss_cfg['args'])
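A minimal sketch of building and calling the loss from a config fragment (shapes follow the 5-field setup used elsewhere in this commit; note that the focal loss expects raw logits):

import torch
from loss.builder import build_loss

cfg = {'solver': {'loss': {'name': 'SigmoidFocalLoss', 'args': {'reduction': 'mean'}}}}
loss_fn = build_loss(cfg)

logits = torch.randn(4, 5)   # raw scores for 4 samples, 5 fields
targets = torch.zeros(4, 5)
targets[0, 2] = 1.0          # first sample belongs to field index 2
print(loss_fn(logits, targets))  # scalar tensor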
main.py
0 → 100644
import argparse
import torch
import yaml
from solver.builder import build_solver


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file')
    args = parser.parse_args()

    with open(args.config, 'r') as fp:
        cfg = yaml.load(fp.read(), Loader=yaml.FullLoader)
    # print(cfg)
    # print(torch.cuda.is_available())

    solver = build_solver(cfg)
    solver.run()


if __name__ == '__main__':
    main()
model/__init__.py
0 → 100644
model/builder.py
0 → 100644
import copy
from utils import MODEL_REGISTRY

from .mlp import MLPModel


def build_model(cfg):
    model_cfg = copy.deepcopy(cfg)
    try:
        model_cfg = model_cfg['model']
    except Exception:
        raise KeyError('config should contain a {model} section!')

    model = MODEL_REGISTRY.get(model_cfg['name'])(**model_cfg['args'])

    return model
model/mlp.py
0 → 100644
import torch.nn as nn
from utils.registery import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class MLPModel(nn.Module):

    def __init__(self, activation):
        super().__init__()
        # The configured activation name is stored, but the stack below uses ReLU directly.
        self.activation_fn = activation
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(29*8, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 5),
            nn.Sigmoid(),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
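A quick shape check for the model (a minimal sketch; the 29x8 input corresponds to the 28 anchor texts plus one candidate box emitted by data/create_dataset.py):

import torch
from model.mlp import MLPModel

model = MLPModel(activation='relu')
x = torch.randn(4, 29, 8)   # batch of 4 samples, 29 boxes x 8 coordinates
out = model(x)              # Flatten -> 232 -> 512 -> 256 -> 5, then Sigmoid
print(out.shape)            # torch.Size([4, 5]), values in (0, 1)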
optimizer/__init__.py
0 → 100644
optimizer/builder.py
0 → 100644
import torch
import inspect
from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
import copy

def register_torch_optimizers():
    """
    Register all optimizers implemented by torch
    """
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer):
            OPTIMIZER_REGISTRY.register()(_optim)

def build_optimizer(cfg):
    register_torch_optimizers()
    optimizer_cfg = copy.deepcopy(cfg)
    try:
        optimizer_cfg = optimizer_cfg['solver']['optimizer']
    except Exception:
        raise KeyError('config should contain a {solver.optimizer} section!')

    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])

def register_torch_lr_scheduler():
    """
    Register all lr_schedulers implemented by torch
    """
    for module_name in dir(torch.optim.lr_scheduler):
        if module_name.startswith('__'):
            continue

        _scheduler = getattr(torch.optim.lr_scheduler, module_name)
        if inspect.isclass(_scheduler) and issubclass(_scheduler, torch.optim.lr_scheduler._LRScheduler):
            LR_SCHEDULER_REGISTRY.register()(_scheduler)

def build_lr_scheduler(cfg):
    register_torch_lr_scheduler()
    scheduler_cfg = copy.deepcopy(cfg)
    try:
        scheduler_cfg = scheduler_cfg['solver']['lr_scheduler']
    except Exception:
        raise KeyError('config should contain a {solver.lr_scheduler} section!')
    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])
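Note that build_optimizer and build_lr_scheduler return the registered classes rather than instances; the solver instantiates them with the args from the config. A minimal sketch:

import torch
from optimizer.builder import build_optimizer, build_lr_scheduler

cfg = {'solver': {'optimizer': {'name': 'Adam', 'args': {'lr': 1e-4, 'weight_decay': 5e-5}},
                  'lr_scheduler': {'name': 'StepLR', 'args': {'step_size': 15, 'gamma': 0.1}}}}

params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = build_optimizer(cfg)(params, **cfg['solver']['optimizer']['args'])
scheduler = build_lr_scheduler(cfg)(optimizer, **cfg['solver']['lr_scheduler']['args'])
print(type(optimizer).__name__, type(scheduler).__name__)  # Adam StepLR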
requirements.txt
0 → 100644
solver/__init__.py
0 → 100644
solver/builder.py
0 → 100644
import copy

from utils.registery import SOLVER_REGISTRY
from .mlp_solver import MLPSolver


def build_solver(cfg):
    cfg = copy.deepcopy(cfg)

    try:
        solver_cfg = cfg['solver']
    except Exception:
        raise KeyError('config should contain a {solver} section!')

    return SOLVER_REGISTRY.get(solver_cfg['name'])(cfg)
solver/mlp_solver.py
0 → 100644
import os
import copy
import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class MLPSolver(object):

    def __init__(self, cfg):
        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except Exception:
            raise KeyError('config should contain epoch in {solver.args}')

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        # Rebuild a single class id per sample: argmax + 1 if any score clears the
        # threshold, otherwise 0 ("other"); all-zero ground-truth rows also map to 0.
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

    def train_loop(self):
        self.model.train()

        train_loss = 0
        for batch, (X, y) in enumerate(self.train_loader):
            pred = self.model(X)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        train_loss /= self.train_loader_size
        self.logger.info(f'train mean loss: {train_loss :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss, correct = 0, 0
        for X, y in self.val_loader:
            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct :.4f}, val loss: {val_loss :.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            # lr_scheduler.step()

        self.logger.info('==> End Training')

        # for X, y in self.train_loader:
        #     print(X.size())
        #     print(y.size())

        #     pred = self.model(X)
        #     print(pred)
        #     print(y)

        #     loss = self.loss_fn(pred, y, reduction="mean")
        #     print(loss)

        #     break

        # y_true = [
        #     [0, 1, 0],
        #     [0, 1, 0],
        #     [0, 0, 1],
        #     [0, 0, 0],
        # ]
        # y_pred = [
        #     [0.1, 0.8, 0.9],
        #     [0.2, 0.8, 0.1],
        #     [0.2, 0.1, 0.85],
        #     [0.2, 0.6, 0.1],
        # ]
        # acc_num = self.evaluate(torch.tensor(y_pred), torch.tensor(y_true))
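A worked example of the evaluate logic above, reusing the commented-out 3-class tensors: the predictions rebuild to class ids [3, 2, 3, 2] (argmax + 1, kept because every row's max clears the 0.5 threshold), the ground truth rebuilds to [2, 2, 3, 0] (the all-zero row maps to 0, i.e. "other"), so 2 of the 4 samples match:

import torch
from solver.mlp_solver import MLPSolver

y_true = torch.tensor([[0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]])
y_pred = torch.tensor([[0.1, 0.8, 0.9], [0.2, 0.8, 0.1], [0.2, 0.1, 0.85], [0.2, 0.6, 0.1]])
print(MLPSolver.evaluate(y_pred, y_true))  # 2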
utils/__init__.py
0 → 100644
utils/logger.py
0 → 100644
import loguru
import os
import datetime

def get_logger_and_log_dir(log_root, suffix):
    """
    get logger and log path

    Args:
        log_root (str): root path of log
        suffix (str): log save name

    Returns:
        logger (loguru.logger): logger object
        log_dir (str): current log directory (date + suffix)
    """
    crt_date = datetime.date.today().strftime('%Y-%m-%d')
    log_dir = os.path.join(log_root, crt_date, suffix)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    logger_path = os.path.join(log_dir, 'logfile.log')
    fmt = '{time:YYYY-MM-DD at HH:mm:ss} | {message}'
    logger = loguru.logger
    logger.add(logger_path, format=fmt)

    return logger, log_dir
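Usage sketch for the logger helper (it writes to <log_root>/<date>/<suffix>/logfile.log and returns that directory, which the solver also uses for checkpoints):

from utils.logger import get_logger_and_log_dir

logger, log_dir = get_logger_and_log_dir(log_root='/Users/zhouweiqi/Downloads/test/logs', suffix='mlp')
logger.info('hello')   # goes to stderr and to <log_dir>/logfile.log
print(log_dir)         # e.g. /Users/zhouweiqi/Downloads/test/logs/2023-01-01/mlp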
utils/registery.py
0 → 100644
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, suffix=None):
        if isinstance(suffix, str):
            name = name + '_' + suffix

        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
                                             f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class, suffix)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj, suffix)

    def get(self, name, suffix='soulwalker'):
        ret = self._obj_map.get(name)
        if ret is None:
            ret = self._obj_map.get(name + '_' + suffix)
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
        if ret is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


DATASET_REGISTRY = Registry('dataset')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
OPTIMIZER_REGISTRY = Registry('optimizer')
SOLVER_REGISTRY = Registry('solver')
LR_SCHEDULER_REGISTRY = Registry('lr_scheduler')
COLLATE_FN_REGISTRY = Registry('collate_fn')
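A short usage sketch of the registry, as a decorator and as a plain call (MyMetric is a hypothetical example class):

from utils.registery import METRIC_REGISTRY

@METRIC_REGISTRY.register()
class MyMetric:  # hypothetical example
    def __call__(self, y_pred, y_true):
        return (y_pred == y_true).float().mean()

# equivalent, without the decorator:
# METRIC_REGISTRY.register(MyMetric)

print(METRIC_REGISTRY.get('MyMetric'))  # <class '...MyMetric'>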