# main.py — entry point: load the YAML config, seed RNGs, initialise the NCCL process group and run the solver.
import argparse
import os
import random

import numpy as np
import torch
import yaml

from core.solver import build_solver


def init_seed(seed=778):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs with ``seed``.

    Also switches cuDNN to deterministic kernels and disables its autotuner
    (``benchmark``), trading some speed for run-to-run reproducibility.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Same seeding order as before: stdlib, NumPy, torch CPU, current CUDA device, all CUDA devices.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False  # autotuning may pick different kernels between runs
    cudnn.deterministic = True

def main():
    """Parse CLI args, load the YAML config, set up NCCL distributed training and run the solver.

    Command-line options:
        --config: Path to the YAML config file (default ``./config/baseline.yaml``).
            The config must contain a ``seed`` key.
        --local_rank / --local-rank: Index of this process's GPU on its node.
            Older ``torch.distributed.launch`` passes ``--local_rank``; PyTorch >= 2.0
            passes ``--local-rank``; ``torchrun`` passes neither and exports the
            ``LOCAL_RANK`` environment variable instead. All three are accepted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./config/baseline.yaml', type=str, help='config file')
    # Accept both spellings: torch.distributed.launch switched to --local-rank in PyTorch 2.0,
    # and argparse would otherwise abort with "unrecognized arguments".
    parser.add_argument('--local_rank', '--local-rank', dest='local_rank', default=-1, type=int,
                        help='local (per-node) process rank for distributed training')
    args = parser.parse_args()
    if args.local_rank < 0:
        # torchrun does not pass a rank flag; without this fallback every process
        # would keep the default device and collide on GPU 0.
        args.local_rank = int(os.environ.get('LOCAL_RANK', -1))

    # Use a context manager so the config file handle is closed promptly.
    # NOTE(review): FullLoader can build some Python objects from tags; only load trusted
    # configs (yaml.safe_load suffices for plain-data configs).
    with open(args.config, 'r') as config_file:
        cfg = yaml.load(config_file, Loader=yaml.FullLoader)
    init_seed(cfg['seed'])

    # Bind this process to its GPU *before* creating the NCCL group so the NCCL
    # communicator is created on the correct device. A negative rank leaves the
    # device unchanged, matching the previous behaviour when no rank is supplied.
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl')

    solver = build_solver(cfg)

    solver.run()

# Run training only when executed as a script (e.g. via a distributed launcher), not on import.
if __name__ == '__main__':
    main()