# unet.py
from typing import Optional

import segmentation_models_pytorch as smp

from utils.registery import MODEL_REGISTRY
from core.model.base_model import BaseModel
@MODEL_REGISTRY.register()
class Unet(BaseModel):
    """Registry-registered wrapper around ``segmentation_models_pytorch.Unet``.

    By default ``forward`` returns ``x + model(x)``, i.e. the wrapped U-Net's
    output is used as an additive correction to the input. Pass
    ``residual=False`` to return the raw network output instead.

    Args:
        encoder_name: Encoder/backbone name, forwarded to ``smp.Unet``.
        encoder_weights: Pretrained encoder weights, forwarded to
            ``smp.Unet`` (smp accepts ``None`` for random initialisation).
        in_channels: Number of input channels, forwarded to ``smp.Unet``.
        classes: Number of output channels of the network, forwarded to
            ``smp.Unet``.
        activation: Output activation name, forwarded to ``smp.Unet``
            (``None`` disables it).
        residual: If True (the default and original behaviour), add the
            input to the network output in ``forward``. The output's
            ``classes`` channels must then broadcast against the input's
            ``in_channels`` channels (normally ``classes == in_channels``).
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: Optional[str] = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: Optional[str] = 'tanh',
                 residual: bool = True):
        super().__init__()
        self.residual = residual
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
        )
        # NOTE(review): this runs after smp has loaded `encoder_weights`.
        # If BaseModel._initialize_weights re-initialises every layer, it
        # would discard the pretrained encoder -- confirm against BaseModel.
        self._initialize_weights()

    def forward(self, x):
        """Apply the U-Net to ``x``; add ``x`` back when ``residual`` is set."""
        out = self.model(x)
        if self.residual:
            # Residual formulation: the network output is an additive
            # correction applied to the input.
            out = x + out
        return out