import logging
from typing import Optional, Union

import torch
from torch import Tensor

from main import load_dotted_path
from models.base_corruption import BaseCorruption, NoCorruption
from models.base_encoder import BaseEncoder


class DenoisingAutoEncoder(BaseEncoder):
    """Auto-encoder whose training inputs are corrupted before being fed to the network.

    The network, optimizer and loss function are whatever ``BaseEncoder`` sets up;
    this subclass only adds an input-corruption step for the training features.
    The comparison (target) features are left unchanged.

    Based on
    https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb
    """

    name = "DenoisingAutoEncoder"

    def __init__(
        self,
        name: Optional[str] = None,
        input_shape: int = 0,
        loss_function=None,
        input_corruption_model: Union[BaseCorruption, str] = NoCorruption,
    ):
        """Initialise the encoder and its input-corruption model.

        Args:
            name: Optional model name, forwarded to ``BaseEncoder``.
            input_shape: Input size, forwarded to ``BaseEncoder``.
            loss_function: Loss function, forwarded to ``BaseEncoder``.
            input_corruption_model: Object exposing ``corrupt_image(x)``, or a
                dotted import path (``str``) that is resolved with
                ``load_dotted_path``.
        """
        self.log = logging.getLogger(self.__class__.__name__)

        # Network, optimizer and loss function are set up by the base encoder.
        super(DenoisingAutoEncoder, self).__init__(name, input_shape, loss_function)

        # Corruption model applied to the training features only.
        # NOTE(review): the default is the ``NoCorruption`` class itself, not an
        # instance; this only works if ``corrupt_image`` is callable on the class
        # (e.g. a staticmethod) -- confirm against models.base_corruption.
        if isinstance(input_corruption_model, str):
            self.input_corruption_model = load_dotted_path(input_corruption_model)
        else:
            self.input_corruption_model = input_corruption_model

    def process_train_features(self, features: Tensor) -> Tensor:
        """Return a float32 tensor of corrupted copies of the training features.

        Each element along the first dimension of ``features`` is passed through
        ``self.input_corruption_model.corrupt_image`` and the results are stacked
        back along a new first dimension.

        Args:
            features: Batch of training samples, iterated along dimension 0.

        Returns:
            Tensor of dtype ``torch.float32`` with one corrupted entry per sample.
            An empty batch yields an empty tensor that keeps the trailing shape
            of ``features``.
        """
        if len(features) == 0:
            # torch.stack rejects an empty list, so build the empty result directly.
            return torch.empty(
                (0, *features.shape[1:]), dtype=torch.float32, device=features.device
            )

        # torch.tensor() cannot build a tensor from a list of multi-element tensors,
        # so convert each corrupted sample (tensor, ndarray, list or scalar) with
        # as_tensor and stack them instead.
        corrupted = [
            torch.as_tensor(self.input_corruption_model.corrupt_image(sample))
            for sample in features
        ]
        return torch.stack(corrupted).to(dtype=torch.float32)