# main.py (601 bytes)
import argparse
import torch
import yaml
from solver.builder import build_solver


def main(argv=None):
    """Parse command-line options, load the YAML config, and run the solver.

    The solver is built from the parsed config via ``build_solver``. With
    ``-e``/``--eval`` its ``evaluate()`` method is called; otherwise its
    ``run()`` method is called.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. When ``None`` (the default), ``argparse`` reads
            ``sys.argv[1:]``, so existing script usage is unchanged; passing
            a list makes the entry point callable from other code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./config/mlp.yaml', type=str, help='config file')
    parser.add_argument('-e', '--eval', action="store_true",
                        help='call solver.evaluate() instead of solver.run()')
    args = parser.parse_args(argv)

    # A context manager closes the handle deterministically; the previous
    # `open(...).read()` left closing to the garbage collector. YAML is read
    # as UTF-8 rather than the platform's locale encoding.
    with open(args.config, 'r', encoding='utf-8') as config_file:
        # NOTE(review): FullLoader is less restrictive than yaml.safe_load and
        # the path comes from the command line, so only load trusted config
        # files. Consider yaml.safe_load if configs never use Python tags.
        cfg = yaml.load(config_file, Loader=yaml.FullLoader)

    solver = build_solver(cfg)

    if args.eval:
        solver.evaluate()
    else:
        solver.run()


# Invoke the command-line entry point only when this file is executed as a
# script; importing the module has no side effects beyond definitions.
if __name__ == "__main__":
    main()