# builder.py
import copy

from utils.registery import SOLVER_REGISTRY
from .mlp_solver import MLPSolver
from .vit_solver import VITSolver
from .sl_solver import SLSolver


def build_solver(cfg):
    """Build the solver selected by ``cfg['solver']['name']``.

    The configuration is deep-copied before use, so the solver is
    constructed from its own copy rather than the caller's object.

    Args:
        cfg: Subscriptable configuration holding a ``'solver'`` entry,
            which in turn holds a ``'name'`` entry used to look up the
            solver in ``SOLVER_REGISTRY``.

    Returns:
        The result of calling the registered solver with the copied
        configuration.

    Raises:
        KeyError: If ``cfg`` has no ``'solver'`` entry, or if the registry
            lookup yields ``None`` for the requested name.
    """
    cfg = copy.deepcopy(cfg)

    try:
        solver_cfg = cfg['solver']
    except KeyError as err:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception so the caller sees the intended message.
        raise KeyError("cfg should contain a 'solver' entry!") from err

    solver_name = solver_cfg['name']
    solver_cls = SOLVER_REGISTRY.get(solver_name)
    # NOTE(review): assumes the registry's ``get`` may return None for an
    # unknown name (dict-like behaviour) -- confirm against utils.registery.
    if solver_cls is None:
        raise KeyError(
            'solver {!r} is not registered in SOLVER_REGISTRY'.format(solver_name)
        )

    return solver_cls(cfg)