# base_model.py
# 850 Bytes
import torch
import torch.nn as nn
from abc import ABCMeta
import math
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Base model class providing a shared weight-initialization routine.

    ``ABCMeta`` is kept as the metaclass for compatibility, but this class
    declares no abstract methods, so the metaclass does not by itself prevent
    instantiation.
    """

    def __init__(self):
        super().__init__()

    def _initialize_weights(self):
        """Re-initialize the parameters of supported submodules in place.

        Walks ``self.modules()`` (this module and all of its descendants) and
        applies:

        * ``Conv1d`` / ``Conv2d`` / ``Conv3d`` and ``Linear``: Xavier-uniform
          weight; bias set to 0 when the layer has a bias.
        * ``LayerNorm`` and ``BatchNorm1d`` / ``BatchNorm2d`` / ``BatchNorm3d``:
          weight set to 1 and bias set to 0, each only when that parameter
          exists.

        Modules of any other type are left untouched. ``__init__`` does not
        call this method, so a subclass has to invoke it itself (presumably
        after constructing its layers).
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(
                module,
                (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d),
            ):
                # Norm layers built with affine=False / elementwise_affine=False
                # expose weight (and bias) as None; skip them rather than
                # passing None to nn.init, which would raise.
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)