# helper.py: utilities for saving YAML configs and model checkpoints.
import torch
import yaml
import os


def save_dict_to_yaml(dict_value, save_path):
    """Write ``dict_value`` to ``save_path`` as YAML.

    The target file is opened in text mode with UTF-8 encoding and is
    truncated/overwritten if it already exists. Keys are emitted in their
    existing order (``sort_keys=False``) instead of being sorted alphabetically.

    Args:
        dict_value: Object to serialize with ``yaml.dump``.
        save_path: Destination file path.
    """
    dump_options = {'sort_keys': False}
    with open(save_path, mode='w', encoding='utf-8') as yaml_stream:
        # Stream directly into the open file rather than building a string first.
        yaml.dump(dict_value, stream=yaml_stream, **dump_options)


def save_checkpoint(model, cfg, log_path, epoch_id):
    """Save the run configuration and the model weights for one epoch.

    Two files are written into ``log_path``:

    * ``config.yaml``: ``cfg`` serialized via ``save_dict_to_yaml``. It is
      rewritten on every call.
    * ``ckpt_epoch_<epoch_id>.pt``: the model's ``state_dict()`` saved with
      ``torch.save``.

    Args:
        model: Model to checkpoint. If it has a ``.module`` attribute (as
            ``torch.nn.DataParallel`` / ``DistributedDataParallel`` wrappers
            do), the wrapped module's ``state_dict`` is saved. Otherwise
            ``model``'s own ``state_dict`` is saved.
        cfg: Configuration passed through unchanged to ``save_dict_to_yaml``.
        log_path: Output directory. It is created if missing; an empty string
            means the current working directory.
        epoch_id: Value interpolated into the checkpoint filename.
    """
    # Create the output directory up front so a fresh run directory works.
    # os.makedirs('') would raise, so skip it for the cwd case.
    if log_path:
        os.makedirs(log_path, exist_ok=True)

    save_dict_to_yaml(cfg, os.path.join(log_path, 'config.yaml'))

    # Accept both wrapped (DataParallel/DDP-style, exposing `.module`) and bare
    # models; a bare nn.Module has no `.module` attribute.
    unwrapped = getattr(model, 'module', model)
    ckpt_path = os.path.join(log_path, f'ckpt_epoch_{epoch_id}.pt')
    torch.save(unwrapped.state_dict(), ckpt_path)