first commit · 0 parents
Showing 23 changed files with 901 additions and 0 deletions
.gitignore
0 → 100644
README.md
0 → 100644
config/mlp.yaml
0 → 100644
seed: 3407

dataset:
  name: 'CoordinatesData'
  args:
    data_root: '/Users/zhouweiqi/Downloads/gcfp/data/dataset'
    train_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/train.csv'
    val_anno_file: '/Users/zhouweiqi/Downloads/gcfp/data/dataset/valid.csv'

dataloader:
  batch_size: 32
  num_workers: 4
  pin_memory: true
  shuffle: true

model:
  name: 'MLPModel'
  args:
    activation: 'relu'

solver:
  name: 'MLPSolver'
  args:
    epoch: 100

  optimizer:
    name: 'Adam'
    args:
      lr: !!float 1e-4
      weight_decay: !!float 5e-5

  lr_scheduler:
    name: 'StepLR'
    args:
      step_size: 15
      gamma: 0.1

  loss:
    name: 'SigmoidFocalLoss'
    # name: 'CrossEntropyLoss'
    args:
      reduction: "mean"

  logger:
    log_root: '/Users/zhouweiqi/Downloads/test/logs'
    suffix: 'mlp'
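For reference, a minimal sketch of how the builders in this commit consume this layout (illustrative, not part of the commit; it only assumes the file sits at ./config/mlp.yaml as in main.py): dataset, dataloader and model live at the top level, while optimizer, lr_scheduler, loss and logger are read from under solver.

import yaml

# parse the config the same way main.py does
with open('./config/mlp.yaml') as fp:
    cfg = yaml.load(fp.read(), Loader=yaml.FullLoader)

print(cfg['dataset']['name'])              # 'CoordinatesData'
print(cfg['solver']['optimizer']['name'])  # 'Adam'
print(cfg['solver']['loss']['args'])       # {'reduction': 'mean'}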
data/CoordinatesData.py
0 → 100644
import os
import json
import torch
from torch.utils.data import Dataset
import pandas as pd
from utils.registery import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class CoordinatesData(Dataset):

    def __init__(self,
                 data_root: str = '/Users/zhouweiqi/Downloads/gcfp/data/dataset',
                 anno_file: str = 'train.csv',
                 phase: str = 'train'):
        self.data_root = data_root
        self.df = pd.read_csv(anno_file)
        self.phase = phase

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        series = self.df.iloc[idx]
        name = series['name']

        # each sample JSON holds [input_coordinates_list, label_list],
        # as written by data/create_dataset.py
        with open(os.path.join(self.data_root, self.phase, name), 'r') as fp:
            input_coordinates_list, label_list = json.load(fp)

        input_coordinates = torch.tensor(input_coordinates_list)
        label = torch.tensor(label_list).float()

        return input_coordinates, label
data/__init__.py
0 → 100644
data/builder.py
0 → 100644
import copy
from torch.utils.data import DataLoader
from utils.registery import DATASET_REGISTRY

from .CoordinatesData import CoordinatesData


def build_dataset(cfg):

    dataset_cfg = copy.deepcopy(cfg)
    try:
        dataset_cfg = dataset_cfg['dataset']
    except Exception:
        raise KeyError("config should contain 'dataset'!")

    train_cfg = copy.deepcopy(dataset_cfg)
    val_cfg = copy.deepcopy(dataset_cfg)
    train_cfg['args']['anno_file'] = train_cfg['args'].pop('train_anno_file')
    train_cfg['args'].pop('val_anno_file', None)
    train_cfg['args']['phase'] = 'train'
    val_cfg['args']['anno_file'] = val_cfg['args'].pop('val_anno_file')
    val_cfg['args'].pop('train_anno_file', None)
    val_cfg['args']['phase'] = 'valid'

    train_data = DATASET_REGISTRY.get(cfg['dataset']['name'])(**train_cfg['args'])
    val_data = DATASET_REGISTRY.get(cfg['dataset']['name'])(**val_cfg['args'])

    return train_data, val_data


def build_dataloader(cfg):

    dataloader_cfg = copy.deepcopy(cfg)
    try:
        dataloader_cfg = dataloader_cfg['dataloader']
    except Exception:
        raise KeyError("config should contain 'dataloader'!")

    train_ds, val_ds = build_dataset(cfg)

    # both loaders share the same settings (batch_size, num_workers, pin_memory, shuffle)
    train_loader = DataLoader(train_ds, **dataloader_cfg)

    val_loader = DataLoader(val_ds, **dataloader_cfg)

    return train_loader, val_loader
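A minimal usage sketch for these two builders (illustrative, not part of the commit), assuming the config above has been parsed into cfg:

import yaml
from data.builder import build_dataloader

with open('./config/mlp.yaml') as fp:
    cfg = yaml.load(fp.read(), Loader=yaml.FullLoader)

train_loader, val_loader = build_dataloader(cfg)
x, y = next(iter(train_loader))
# expected torch.Size([32, 29, 8]) and torch.Size([32, 5]),
# given how data/create_dataset.py builds each sample (28 key boxes + 1 candidate box, 5 target fields)
print(x.shape, y.shape)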
data/create_dataset.py
0 → 100644
import os
import cv2
import uuid
import json
import random
import copy

import pandas as pd
from tools import get_file_paths, load_json


def text_statistics(go_res_dir):
    """
    Args:
        go_res_dir: str  directory of general-OCR JSON results
    Returns: list  (text, count) pairs for the most frequent texts
    """
    json_count = 0
    text_dict = {}
    go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
    for go_res_json_path in go_res_json_paths:
        print('Info: start {0}'.format(go_res_json_path))
        json_count += 1
        go_res = load_json(go_res_json_path)
        for _, text in go_res.values():
            if text in text_dict:
                text_dict[text] += 1
            else:
                text_dict[text] = 1
    top_text_list = []
    # sort by count, descending
    for text, count in sorted(text_dict.items(), key=lambda x: x[1], reverse=True):
        if text == '':
            continue
        # discard texts that appear in no more than 1/3 of the JSON files
        if count <= json_count // 3:
            break
        top_text_list.append((text, count))
    return top_text_list

def build_anno_file(dataset_dir, anno_file_path):
    img_list = os.listdir(dataset_dir)
    random.shuffle(img_list)
    df = pd.DataFrame(columns=['name'])
    df['name'] = img_list
    df.to_csv(anno_file_path)

def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save_dir):
    """
    Args:
        img_dir: str  image directory
        go_res_dir: str  directory of general-OCR JSON results
        label_dir: str  directory of labelled JSON files
        top_text_list: list  (text, count) pairs for the most frequent texts
        skip_list: list  image names to skip
        save_dir: str  output directory of the dataset
    """
    # if os.path.exists(save_dir):
    #     return
    # else:
    #     os.makedirs(save_dir, exist_ok=True)

    count = 0
    un_count = 0
    top_text_count = len(top_text_list)
    for img_name in sorted(os.listdir(img_dir)):
        if img_name in skip_list:
            print('Info: skip {0}'.format(img_name))
            continue

        print('Info: start {0}'.format(img_name))
        image_path = os.path.join(img_dir, img_name)
        img = cv2.imread(image_path)
        h, w, _ = img.shape
        base_image_name, _ = os.path.splitext(img_name)
        go_res_json_path = os.path.join(go_res_dir, '{0}.json'.format(base_image_name))
        go_res = load_json(go_res_json_path)

        input_key_list = []
        not_found_count = 0
        go_key_set = set()
        for top_text, _ in top_text_list:
            for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), text) in go_res.items():
                if text == top_text:
                    input_key_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
                    go_key_set.add(go_key)
                    break
            else:
                not_found_count += 1
                input_key_list.append([0, 0, 0, 0, 0, 0, 0, 0])
        if not_found_count >= top_text_count // 3:
            print('Info: skip {0} : {1}/{2}'.format(img_name, not_found_count, top_text_count))
            continue

        label_json_path = os.path.join(label_dir, '{0}.json'.format(base_image_name))
        label_res = load_json(label_json_path)

        # target fields: invoice date, invoice code, machine-printed number, vehicle type, phone
        test_group_id = [1, 2, 5, 9, 20]
        group_list = []
        for group_id in test_group_id:
            for item in label_res.get("shapes", []):
                if item.get("group_id") == group_id:
                    x_list = []
                    y_list = []
                    for point in item['points']:
                        x_list.append(point[0])
                        y_list.append(point[1])
                    group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2])
                    break
            else:
                group_list.append(None)

        go_center_list = []
        for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), _) in go_res.items():
            if go_key in go_key_set:
                continue
            xmin = min(x0, x1, x2, x3)
            ymin = min(y0, y1, y2, y3)
            xmax = max(x0, x1, x2, x3)
            ymax = max(y0, y1, y2, y3)
            xcenter = xmin + (xmax - xmin)/2
            ycenter = ymin + (ymax - ymin)/2
            go_center_list.append([xcenter, ycenter, go_key])

        group_go_key_list = []
        for label_center_list in group_list:
            if isinstance(label_center_list, list):
                min_go_key = None
                min_length = None
                for go_x_center, go_y_center, go_key in go_center_list:
                    if go_key in go_key_set:
                        continue
                    length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1])
                    if min_go_key is None or length < min_length:
                        min_go_key = go_key
                        min_length = length
                if min_go_key is not None:
                    go_key_set.add(min_go_key)
                    group_go_key_list.append(min_go_key)
                else:
                    group_go_key_list.append(None)
            else:
                group_go_key_list.append(None)

        src_label_list = [0 for _ in test_group_id]
        for idx, find_go_key in enumerate(group_go_key_list):
            if find_go_key is None:
                continue
            (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res[find_go_key]
            input_list = copy.deepcopy(input_key_list)
            input_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])

            input_label = copy.deepcopy(src_label_list)
            input_label[idx] = 1
            # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, find_go_key)))), 'w') as fp:
            #     json.dump([input_list, input_label], fp)
            count += 1

        for go_key, ((x0, y0, x1, y1, x2, y2, x3, y3), _) in go_res.items():
            if go_key in go_key_set:
                continue
            input_list = copy.deepcopy(input_key_list)
            input_list.append([x0/w, y0/h, x1/w, y1/h, x2/w, y2/h, x3/w, y3/h])
            # with open(os.path.join(save_dir, '{0}.json'.format(uuid.uuid3(uuid.NAMESPACE_DNS, '{0}-{1}'.format(img_name, go_key)))), 'w') as fp:
            #     json.dump([input_list, src_label_list], fp)
            un_count += 1

        # break
    print(count)
    print(un_count)


if __name__ == '__main__':
    base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
    go_dir = os.path.join(base_dir, 'go_res')
    dataset_save_dir = os.path.join(base_dir, 'dataset')
    label_dir = os.path.join(base_dir, 'labeled')

    train_go_path = os.path.join(go_dir, 'train')
    train_image_path = os.path.join(label_dir, 'train', 'image')
    train_label_path = os.path.join(label_dir, 'train', 'label')
    train_dataset_dir = os.path.join(dataset_save_dir, 'train')
    train_anno_file_path = os.path.join(dataset_save_dir, 'train.csv')

    valid_go_path = os.path.join(go_dir, 'valid')
    valid_image_path = os.path.join(label_dir, 'valid', 'image')
    valid_label_path = os.path.join(label_dir, 'valid', 'label')
    valid_dataset_dir = os.path.join(dataset_save_dir, 'valid')
    valid_anno_file_path = os.path.join(dataset_save_dir, 'valid.csv')

    # top_text_list = text_statistics(go_dir)
    # for t in top_text_list:
    #     print(t)

    filter_from_top_text_list = [
        ('机器编号', 496),
        ('购买方名称', 496),
        ('合格证号', 495),
        ('进口证明书号', 495),
        ('机打代码', 494),
        ('车辆类型', 492),
        ('完税凭证号码', 492),
        ('机打号码', 491),
        ('发动机号码', 491),
        ('主管税务', 491),
        ('价税合计', 489),
        ('机关及代码', 489),
        ('销货单位名称', 486),
        ('厂牌型号', 485),
        ('产地', 485),
        ('商检单号', 483),
        ('电话', 476),
        ('开户银行', 472),
        ('车辆识别代号/车架号码', 463),
        ('身份证号码', 454),
        ('吨位', 452),
        ('备注:一车一票', 439),
        ('地', 432),
        ('账号', 431),
        ('统一社会信用代码/', 424),
        ('限乘人数', 404),
        ('税额', 465),
        ('址', 392)
    ]

    skip_list_train = [
        'CH-B101910792-page-12.jpg',
        'CH-B101655312-page-13.jpg',
        'CH-B102278656.jpg',
        'CH-B101846620_page_1_img_0.jpg',
        'CH-B103062528-0.jpg',
        'CH-B102613120-3.jpg',
        'CH-B102997980-3.jpg',
        'CH-B102680060-3.jpg',
        # 'CH-B102995500-2.jpg',  # no value
    ]

    skip_list_valid = [
        'CH-B102897920-2.jpg',
        'CH-B102551284-0.jpg',
        'CH-B102879376-2.jpg',
        'CH-B101509488-page-16.jpg',
        'CH-B102708352-2.jpg',
    ]

    # build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)

    build_dataset(valid_image_path, valid_go_path, valid_label_path, filter_from_top_text_list, skip_list_valid, valid_dataset_dir)

    # build_anno_file(train_dataset_dir, train_anno_file_path)
    # build_anno_file(valid_dataset_dir, valid_anno_file_path)
data/tools.py
0 → 100644
import json
import os


def get_exclude_paths(input_path, exclude_list=[]):
    """
    Args:
        input_path: str  target directory
        exclude_list: list  files or directories to exclude, relative to input_path
    Returns: set  absolute paths of the excluded files or directories
    """
    exclude_paths_set = set()
    if os.path.isdir(input_path):
        for path in exclude_list:
            abs_path = path if os.path.isabs(path) else os.path.join(input_path, path)
            if not os.path.exists(abs_path):
                print('Warning: exclude path not exists: {0}'.format(abs_path))
                continue
            exclude_paths_set.add(abs_path)
    return exclude_paths_set

def get_file_paths(input_path, suffix_list, exclude_list=[]):
    """
    Args:
        input_path: str  target directory
        suffix_list: list  file suffixes to search for
        exclude_list: list  files or directories to exclude, relative to input_path
    Yields: str  absolute paths of the matching files
    """
    exclude_paths_set = get_exclude_paths(input_path, exclude_list)
    for parent, _, filenames in os.walk(input_path):
        if parent in exclude_paths_set:
            print('Info: exclude path: {0}'.format(parent))
            continue
        for filename in filenames:
            for suffix in suffix_list:
                if filename.endswith(suffix):
                    file_path = os.path.join(parent, filename)
                    break
            else:
                continue
            if file_path in exclude_paths_set:
                print('Info: exclude path: {0}'.format(file_path))
                continue
            yield file_path

def load_json(json_path):
    """
    Args:
        json_path: str  path to a JSON file
    Returns: obj  the parsed JSON object
    """
    with open(json_path, 'r') as fp:
        output = json.load(fp)
    return output
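A small usage sketch for these helpers (illustrative, not part of the commit; the directory and exclude names are hypothetical):

from tools import get_file_paths, load_json

# get_file_paths is a generator, so paths stream in as os.walk visits directories
for json_path in get_file_paths('/path/to/go_res', ['.json'], exclude_list=['skip_me']):
    go_res = load_json(json_path)
    print(json_path, len(go_res))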
loss/__init__.py
0 → 100644
loss/builder.py
0 → 100644
import copy
import torch
import inspect
from utils.registery import LOSS_REGISTRY
from torchvision.ops import sigmoid_focal_loss

class SigmoidFocalLoss(torch.nn.modules.loss._WeightedLoss):

    def __init__(self,
                 weight=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean',
                 alpha: float = 0.25,
                 gamma: float = 2):
        super().__init__(weight, size_average, reduce, reduction)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # note: torchvision's sigmoid_focal_loss applies the sigmoid internally,
        # so `inputs` are expected to be raw logits
        return sigmoid_focal_loss(inputs, targets, self.alpha, self.gamma, self.reduction)


def register_sigmoid_focal_loss():
    LOSS_REGISTRY.register()(SigmoidFocalLoss)


def register_torch_loss():
    for module_name in dir(torch.nn):
        if module_name.startswith('__') or 'Loss' not in module_name:
            continue
        _loss = getattr(torch.nn, module_name)
        if inspect.isclass(_loss) and issubclass(_loss, torch.nn.Module):
            LOSS_REGISTRY.register()(_loss)

def build_loss(cfg):
    register_sigmoid_focal_loss()
    register_torch_loss()
    loss_cfg = copy.deepcopy(cfg)
    try:
        loss_cfg = loss_cfg['solver']['loss']
    except Exception:
        raise KeyError("config should contain 'solver.loss'!")

    # return sigmoid_focal_loss
    return LOSS_REGISTRY.get(loss_cfg['name'])(**loss_cfg['args'])
main.py
0 → 100644
import argparse
import torch
import yaml
from solver.builder import build_solver


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file')
    args = parser.parse_args()

    with open(args.config, 'r') as fp:
        cfg = yaml.load(fp.read(), Loader=yaml.FullLoader)
    # print(cfg)
    # print(torch.cuda.is_available())

    solver = build_solver(cfg)
    solver.run()


if __name__ == '__main__':
    main()
model/__init__.py
0 → 100644
model/builder.py
0 → 100644
import copy
from utils import MODEL_REGISTRY

from .mlp import MLPModel


def build_model(cfg):
    model_cfg = copy.deepcopy(cfg)
    try:
        model_cfg = model_cfg['model']
    except Exception:
        raise KeyError("config should contain 'model'!")

    model = MODEL_REGISTRY.get(model_cfg['name'])(**model_cfg['args'])

    return model
model/mlp.py
0 → 100644
import torch.nn as nn
from utils.registery import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class MLPModel(nn.Module):

    def __init__(self, activation):
        super().__init__()
        # the activation name from the config is stored, but the stack below
        # currently uses a fixed ReLU/Sigmoid combination
        self.activation_fn = activation
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(29*8, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 5),
            nn.Sigmoid(),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
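A quick shape check (illustrative, not part of the commit): the model flattens the 29 x 8 coordinate tensor built in data/create_dataset.py and outputs one score per target field.

import torch
from model.mlp import MLPModel

model = MLPModel(activation='relu')
x = torch.randn(4, 29, 8)   # dummy batch: 4 samples, 29 boxes x 8 normalized coordinates
print(model(x).shape)        # torch.Size([4, 5])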
optimizer/__init__.py
0 → 100644
optimizer/builder.py
0 → 100644
import torch
import inspect
from utils.registery import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
import copy

def register_torch_optimizers():
    """
    Register all optimizers implemented by torch
    """
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer):
            OPTIMIZER_REGISTRY.register()(_optim)

def build_optimizer(cfg):
    register_torch_optimizers()
    optimizer_cfg = copy.deepcopy(cfg)
    try:
        optimizer_cfg = optimizer_cfg['solver']['optimizer']
    except Exception:
        raise KeyError("config should contain 'solver.optimizer'!")

    # returns the optimizer class; the solver instantiates it with the config args
    return OPTIMIZER_REGISTRY.get(optimizer_cfg['name'])

def register_torch_lr_scheduler():
    """
    Register all lr_schedulers implemented by torch
    """
    for module_name in dir(torch.optim.lr_scheduler):
        if module_name.startswith('__'):
            continue

        _scheduler = getattr(torch.optim.lr_scheduler, module_name)
        if inspect.isclass(_scheduler) and issubclass(_scheduler, torch.optim.lr_scheduler._LRScheduler):
            LR_SCHEDULER_REGISTRY.register()(_scheduler)

def build_lr_scheduler(cfg):
    register_torch_lr_scheduler()
    scheduler_cfg = copy.deepcopy(cfg)
    try:
        scheduler_cfg = scheduler_cfg['solver']['lr_scheduler']
    except Exception:
        raise KeyError("config should contain 'solver.lr_scheduler'!")
    return LR_SCHEDULER_REGISTRY.get(scheduler_cfg['name'])
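Both builders return the registered class rather than an instance; the solver then instantiates it with the args from the config. A minimal sketch of that pattern (illustrative, not part of the commit):

import yaml
from model.builder import build_model
from optimizer.builder import build_optimizer, build_lr_scheduler

with open('./config/mlp.yaml') as fp:
    cfg = yaml.load(fp.read(), Loader=yaml.FullLoader)

model = build_model(cfg)
# build_optimizer(cfg) returns e.g. torch.optim.Adam; call it to create the instance
optimizer = build_optimizer(cfg)(model.parameters(), **cfg['solver']['optimizer']['args'])
lr_scheduler = build_lr_scheduler(cfg)(optimizer, **cfg['solver']['lr_scheduler']['args'])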
requirements.txt
0 → 100644
solver/__init__.py
0 → 100644
solver/builder.py
0 → 100644
import copy

from utils.registery import SOLVER_REGISTRY
from .mlp_solver import MLPSolver


def build_solver(cfg):
    cfg = copy.deepcopy(cfg)

    try:
        solver_cfg = cfg['solver']
    except Exception:
        raise KeyError("config should contain 'solver'!")

    return SOLVER_REGISTRY.get(solver_cfg['name'])(cfg)
solver/mlp_solver.py
0 → 100644
import os
import copy
import torch

from model import build_model
from data import build_dataloader
from optimizer import build_optimizer, build_lr_scheduler
from loss import build_loss
from utils import SOLVER_REGISTRY, get_logger_and_log_dir


@SOLVER_REGISTRY.register()
class MLPSolver(object):

    def __init__(self, cfg):
        self.cfg = copy.deepcopy(cfg)

        self.train_loader, self.val_loader = build_dataloader(cfg)
        self.train_loader_size, self.val_loader_size = len(self.train_loader), len(self.val_loader)
        self.train_dataset_size, self.val_dataset_size = len(self.train_loader.dataset), len(self.val_loader.dataset)

        # BatchNorm ?
        self.model = build_model(cfg)

        self.loss_fn = build_loss(cfg)

        self.optimizer = build_optimizer(cfg)(self.model.parameters(), **cfg['solver']['optimizer']['args'])

        self.hyper_params = cfg['solver']['args']
        try:
            self.epoch = self.hyper_params['epoch']
        except Exception:
            raise KeyError("config should contain 'epoch' in solver.args")

        self.logger, self.log_dir = get_logger_and_log_dir(**cfg['solver']['logger'])

    @staticmethod
    def evaluate(y_pred, y_true, thresholds=0.5):
        # rebuild each row as a class index in 1..5, or 0 when no field scores above the threshold
        y_pred_idx = torch.argmax(y_pred, dim=1) + 1
        y_pred_is_other = (torch.amax(y_pred, dim=1) > thresholds).int()
        y_pred_rebuild = torch.multiply(y_pred_idx, y_pred_is_other)

        y_true_idx = torch.argmax(y_true, dim=1) + 1
        y_true_is_other = torch.sum(y_true, dim=1)
        y_true_rebuild = torch.multiply(y_true_idx, y_true_is_other)

        return torch.sum((y_pred_rebuild == y_true_rebuild).int()).item()

    def train_loop(self):
        self.model.train()

        train_loss = 0
        for batch, (X, y) in enumerate(self.train_loader):
            pred = self.model(X)

            # loss = self.loss_fn(pred, y, reduction="mean")
            loss = self.loss_fn(pred, y)
            train_loss += loss.item()

            if batch % 100 == 0:
                loss_value, current = loss.item(), batch
                self.logger.info(f'train iteration: {current}/{self.train_loader_size}, train loss: {loss_value :.4f}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        train_loss /= self.train_loader_size
        self.logger.info(f'train mean loss: {train_loss :.4f}')

    @torch.no_grad()
    def val_loop(self, t):
        self.model.eval()

        val_loss, correct = 0, 0
        for X, y in self.val_loader:
            pred = self.model(X)

            correct += self.evaluate(pred, y)

            loss = self.loss_fn(pred, y)
            val_loss += loss.item()

        correct /= self.val_dataset_size
        val_loss /= self.val_loader_size

        self.logger.info(f"val accuracy: {correct :.4f}, val loss: {val_loss :.4f}")

    def save_checkpoint(self, epoch_id):
        self.model.eval()
        torch.save(self.model.state_dict(), os.path.join(self.log_dir, f'ckpt_epoch_{epoch_id}.pt'))

    def run(self):
        self.logger.info('==> Start Training')
        print(self.model)

        # lr_scheduler = build_lr_scheduler(self.cfg)(self.optimizer, **self.cfg['solver']['lr_scheduler']['args'])

        for t in range(self.epoch):
            self.logger.info(f'==> epoch {t + 1}')

            self.train_loop()
            self.val_loop(t + 1)
            self.save_checkpoint(t + 1)

            # lr_scheduler.step()

        self.logger.info('==> End Training')

        # for X, y in self.train_loader:
        #     print(X.size())
        #     print(y.size())

        #     pred = self.model(X)
        #     print(pred)
        #     print(y)

        #     loss = self.loss_fn(pred, y, reduction="mean")
        #     print(loss)

        #     break

        # y_true = [
        #     [0, 1, 0],
        #     [0, 1, 0],
        #     [0, 0, 1],
        #     [0, 0, 0],
        # ]
        # y_pred = [
        #     [0.1, 0.8, 0.9],
        #     [0.2, 0.8, 0.1],
        #     [0.2, 0.1, 0.85],
        #     [0.2, 0.6, 0.1],
        # ]
        # acc_num = self.evaluate(torch.tensor(y_pred), torch.tensor(y_true))
utils/__init__.py
0 → 100644
utils/logger.py
0 → 100644
import loguru
import os
import datetime

def get_logger_and_log_dir(log_root, suffix):
    """
    get logger and log path

    Args:
        log_root (str): root path of log
        suffix (str): log save name

    Returns:
        logger (loguru.logger): logger object
        log_dir (str): current log directory (with suffix)
    """
    crt_date = datetime.date.today().strftime('%Y-%m-%d')
    log_dir = os.path.join(log_root, crt_date, suffix)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    logger_path = os.path.join(log_dir, 'logfile.log')
    fmt = '{time:YYYY-MM-DD at HH:mm:ss} | {message}'
    logger = loguru.logger
    logger.add(logger_path, format=fmt)

    return logger, log_dir
utils/registery.py
0 → 100644
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, suffix=None):
        if isinstance(suffix, str):
            name = name + '_' + suffix

        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
                                             f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class, suffix)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj, suffix)

    def get(self, name, suffix='soulwalker'):
        ret = self._obj_map.get(name)
        if ret is None:
            ret = self._obj_map.get(name + '_' + suffix)
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
        if ret is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


DATASET_REGISTRY = Registry('dataset')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
OPTIMIZER_REGISTRY = Registry('optimizer')
SOLVER_REGISTRY = Registry('solver')
LR_SCHEDULER_REGISTRY = Registry('lr_scheduler')
COLLATE_FN_REGISTRY = Registry('collate_fn')
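A small usage sketch of the registry pattern in decorator form (illustrative, not part of the commit; TinyModel is a hypothetical class):

from utils.registery import MODEL_REGISTRY

@MODEL_REGISTRY.register()
class TinyModel:
    pass

print('TinyModel' in MODEL_REGISTRY)    # True
print(MODEL_REGISTRY.get('TinyModel'))  # returns the class object registered under its __name__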