# builder.py
import copy
from utils import MODEL_REGISTRY

from .mlp import MLPModel
from .vit import VisionTransformer
from .seq_labeling import SLTransformer


def build_model(cfg):
    """Instantiate the model described by the ``model`` section of *cfg*.

    The ``model`` section must contain a ``name`` key, which is looked up in
    ``MODEL_REGISTRY`` to obtain a constructor. It may contain an ``args``
    mapping whose items are forwarded to that constructor as keyword
    arguments. A missing or empty ``args`` means "no arguments".

    Args:
        cfg: Configuration mapping containing a ``model`` entry. It is not
            mutated: the ``model`` section is deep-copied before use, so the
            constructor cannot alter the caller's config through ``args``.

    Returns:
        The object returned by the registered constructor.

    Raises:
        KeyError: If *cfg* has no ``model`` section, the section has no
            ``name``, or ``MODEL_REGISTRY.get`` returns ``None`` for the name.
    """
    try:
        model_section = cfg['model']
    except (KeyError, TypeError) as err:
        # TypeError covers a non-subscriptable cfg (e.g. None).
        raise KeyError("cfg should contain a 'model' section") from err

    # Copy only the part we use; the constructor receives independent objects.
    model_cfg = copy.deepcopy(model_section)

    try:
        name = model_cfg['name']
    except (KeyError, TypeError) as err:
        raise KeyError("the 'model' section should contain a 'name'") from err

    builder = MODEL_REGISTRY.get(name)
    if builder is None:
        # Guard in case the registry's ``get`` returns None for unknown names
        # instead of raising; otherwise we'd fail later with "None is not callable".
        raise KeyError(f"model {name!r} is not registered in MODEL_REGISTRY")

    # ``args`` may be absent or null (e.g. an empty YAML mapping).
    kwargs = model_cfg.get('args') or {}
    return builder(**kwargs)