# unet.py
from typing import Optional

import segmentation_models_pytorch as smp

from utils.registery import MODEL_REGISTRY
from core.model.base_model import BaseModel


@MODEL_REGISTRY.register()
class Unet(BaseModel):
    """Registry wrapper around ``segmentation_models_pytorch.Unet``.

    By default the wrapped network acts as a residual branch: ``forward``
    returns ``x + model(x)``. That sum only works when the network output
    broadcasts against ``x``, which normally means ``classes == in_channels``
    (both 3 by default). With the default ``'tanh'`` activation, each element
    of the predicted correction lies in ``[-1, 1]``.

    Args:
        encoder_name: Name of the smp encoder backbone.
        encoder_weights: Pretrained weights for the encoder, passed to smp
            (e.g. ``'imagenet'``), or ``None`` for random initialization.
        in_channels: Number of channels of the input tensor.
        classes: Number of channels produced by the network head.
        activation: Activation applied by smp after the final layer, or
            ``None`` for no activation.
        residual: If ``True`` (default, the original behaviour), add the
            input back onto the network output. If ``False``, return the raw
            network output.
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: Optional[str] = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: Optional[str] = 'tanh',
                 residual: bool = True):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
        )
        self.residual = residual

        # NOTE(review): `_initialize_weights` is inherited from BaseModel and
        # is not visible here. If it re-initializes every layer, it overwrites
        # the pretrained `encoder_weights` that smp just loaded -- confirm.
        self._initialize_weights()

    def forward(self, x):
        """Run the wrapped U-Net on ``x``.

        Args:
            x: Input tensor, passed unchanged to the wrapped ``smp.Unet``.

        Returns:
            ``x + model(x)`` when ``residual`` is enabled, otherwise
            ``model(x)``.
        """
        out = self.model(x)
        if self.residual:
            # The network predicts a correction that is added back onto the
            # input, so its output must broadcast against ``x``.
            out = x + out
        return out