# builder.py
import copy
from utils import MODEL_REGISTRY
from .mlp import MLPModel
def build_model(cfg):
    """Instantiate the model described by the ``'model'`` section of *cfg*.

    The section is read as ``{'name': <registry key>, 'args': {<kwargs>}}``.
    The constructor is looked up with ``MODEL_REGISTRY.get(name)`` and called
    with ``args`` unpacked as keyword arguments. ``args`` may be omitted or
    null, meaning "no constructor arguments".

    Args:
        cfg: Mapping with a ``'model'`` entry. It is not modified.

    Returns:
        Whatever the registered constructor returns.

    Raises:
        KeyError: If *cfg* has no ``'model'`` section, the section has no
            ``'name'``, or nothing is registered under that name.
    """
    try:
        section = cfg['model']
    except (KeyError, TypeError) as err:
        # TypeError covers a non-mapping cfg (e.g. None), so callers get the
        # same clear error in both cases.
        raise KeyError("config should contain a 'model' section") from err

    # Deep-copy only the model section so the constructor cannot mutate
    # nested values that still belong to the caller's config.
    model_cfg = copy.deepcopy(section)
    name = model_cfg['name']

    model_cls = MODEL_REGISTRY.get(name)
    # NOTE(review): assumes MODEL_REGISTRY.get returns None for unknown names
    # (dict-like). If it raises instead, this branch is simply never taken.
    if model_cls is None:
        raise KeyError("no model registered under name {!r}".format(name))

    return model_cls(**(model_cfg.get('args') or {}))