# unet_skip.py
from typing import Optional

import segmentation_models_pytorch as smp

from utils.registery import MODEL_REGISTRY
from core.model.base_model import BaseModel
@MODEL_REGISTRY.register()
class UnetSkip(BaseModel):
    """U-Net whose output is treated as a residual added back onto its input.

    The wrapped ``smp.Unet`` produces ``residual``; the model's reconstruction
    is ``x + residual`` (a global skip connection). Because of that addition,
    ``classes`` (the U-Net's output channel count) is expected to equal
    ``in_channels``; other combinations either fail in ``forward`` or silently
    rely on tensor broadcasting when one of the two is 1.
    """

    def __init__(self,
                 encoder_name: str = 'resnet50',
                 encoder_weights: Optional[str] = 'imagenet',
                 in_channels: int = 3,
                 classes: int = 3,
                 activation: Optional[str] = 'tanh',
                 **unet_kwargs):
        """Build the wrapped ``smp.Unet``.

        Args:
            encoder_name: Encoder backbone name passed to ``smp.Unet``.
            encoder_weights: Pretrained-weights identifier passed to
                ``smp.Unet``; smp accepts ``None`` for random initialization.
            in_channels: Number of channels of the input ``x``.
            classes: Number of output channels of the U-Net, i.e. of the
                residual. Should normally equal ``in_channels`` (see class doc).
            activation: Activation ``smp.Unet`` applies to its output. The
                default ``'tanh'`` bounds the residual to (-1, 1); ``None``
                disables it.
            **unet_kwargs: Extra keyword arguments forwarded unchanged to
                ``smp.Unet`` (e.g. ``encoder_depth``, ``decoder_channels``).
        """
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
            **unet_kwargs,
        )
        # NOTE(review): _initialize_weights() comes from BaseModel and runs
        # AFTER smp has loaded the pretrained encoder weights. If it
        # re-initializes every layer, it discards those weights — confirm
        # that it only touches the intended modules.
        self._initialize_weights()

    def forward(self, x):
        """Run the U-Net and add its output back onto the input.

        Args:
            x: Input batch fed to ``smp.Unet`` — presumably a (B, C, H, W)
                tensor with C == in_channels; TODO confirm against callers.

        Returns:
            dict with ``'reconstruction'`` (``x + residual``) and
            ``'residual'`` (the raw U-Net output, after ``activation``).
        """
        residual = self.model(x)
        # Global skip connection: the U-Net output is a correction applied to x.
        reconstruction = x + residual
        return {'reconstruction': reconstruction, 'residual': residual}