# unet_skip.py (928 bytes)
from typing import Callable, Optional, Union

import segmentation_models_pytorch as smp

from utils.registery import MODEL_REGISTRY
from core.model.base_model import BaseModel


@MODEL_REGISTRY.register()
class UnetSkip(BaseModel):
    """U-Net that predicts a residual which is added back onto its input.

    The wrapped ``smp.Unet`` maps the input ``x`` to ``residual``, and the
    model's reconstruction is ``x + residual`` (a global skip connection).

    Args:
        encoder_name: Encoder backbone name passed to ``smp.Unet``.
        encoder_weights: Pretrained encoder weights passed to ``smp.Unet``
            (e.g. ``'imagenet'``), or ``None`` for random initialization.
        in_channels: Number of channels of the input ``x``.
        classes: Number of output channels of the U-Net, i.e. of ``residual``.
            It must be broadcast-compatible with ``in_channels`` because the
            two tensors are summed in ``forward``.
        activation: Activation applied by ``smp.Unet`` to its output, as a
            name (``'tanh'`` by default, which bounds the residual to
            (-1, 1)), a callable, or ``None``.
        **unet_kwargs: Extra keyword arguments forwarded unchanged to
            ``smp.Unet`` (e.g. ``encoder_depth``, ``decoder_channels``).

    Raises:
        ValueError: If ``classes`` and ``in_channels`` cannot be broadcast
            together. Without this check the error would only surface on the
            first ``forward`` call.
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: Optional[str] = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: Optional[Union[str, Callable]] = 'tanh',
                 **unet_kwargs):
        super().__init__()

        # ``x + residual`` broadcasts along the channel dim only when the
        # sizes match or one of them is 1; reject anything else up front,
        # before any (possibly downloaded) encoder weights are built.
        if classes != in_channels and 1 not in (classes, in_channels):
            raise ValueError(
                f"classes ({classes}) must equal in_channels ({in_channels}), "
                "or one of them must be 1, so the residual can be added to "
                "the input"
            )

        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
            **unet_kwargs,
        )

        # NOTE(review): if BaseModel._initialize_weights re-initializes every
        # submodule, it overwrites the pretrained encoder weights loaded
        # above. Confirm it skips (or is meant to reset) the encoder.
        self._initialize_weights()

    def forward(self, x):
        """Predict a residual for ``x`` and add it back onto ``x``.

        Args:
            x: Input tensor fed to ``smp.Unet``. It is presumably
                (batch, in_channels, H, W); TODO confirm against callers.

        Returns:
            dict: ``'reconstruction'`` is ``x + residual``, and
            ``'residual'`` is the raw (post-activation) U-Net output.
        """
        residual = self.model(x)
        reconstruction = x + residual

        return {'reconstruction': reconstruction, 'residual': residual}