import logging
import math
import torch
from typing import Optional
from torch.nn.modules.loss import _Loss
from models.base_encoder import BaseEncoder


class SparseL1AutoEncoder(BaseEncoder):
    """Fully connected autoencoder trained with an L1 penalty on its ReLU activations.

    Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/

    The hidden layers are wider than the input (1.5x, then 2x, then 1.5x the
    input width). During training, ``process_loss`` adds
    ``regularization_parameter`` times the summed mean absolute ReLU activation
    to the reconstruction loss, which encourages sparse activations.
    """
    name = "SparseL1AutoEncoder"

    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None,
                 regularization_parameter: float = 0.001):
        """Build the over-complete network and its optimizer.

        :param name: optional instance name, forwarded to ``BaseEncoder``.
        :param input_shape: number of input features; the output layer has the same width.
        :param loss_function: reconstruction loss, forwarded to ``BaseEncoder``.
        :param regularization_parameter: weight (lambda) of the L1 sparsity term
            added in ``process_loss``.
        """
        self.log = logging.getLogger(self.__class__.__name__)
        # Call superclass to initialize parameters.
        super().__init__(name, input_shape, loss_function)

        # Override the base network: a sparse encoder uses wider intermediate layers,
        # 1.5 times the input width in the first layer and 2 times in the second
        # (compared to the original input shape).
        wide = math.floor(input_shape * 1.5)
        widest = input_shape * 2
        self.network = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_shape, out_features=wide),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=wide, out_features=widest),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=widest, out_features=wide),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=wide, out_features=input_shape),
            torch.nn.ReLU(),
        )

        # The network's parameters were just replaced, so the optimizer must be
        # rebuilt over the new ones: Adam with learning rate 1e-3.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        # The reconstruction loss function is whatever BaseEncoder set up.
        # Regularization parameter (lambda) for the L1 sparse loss term.
        self.regularization_parameter = regularization_parameter

    def get_sparse_loss(self, images):
        """Return the L1 sparsity penalty for a forward pass of ``images``.

        Walks this module's children in registration order and descends into
        ``torch.nn.Sequential`` containers. ``Linear`` and ``ReLU`` layers are
        applied to the running values. After every ``ReLU``, the mean absolute
        activation is added to the penalty. Any other module type is skipped and
        does not transform the values.

        NOTE(review): this recomputes the forward pass instead of reusing the
        activations produced during training, so it costs one extra forward pass.

        :param images: network input (assumed to be a batch of flattened features
            of width ``input_shape`` -- confirm against the caller).
        :return: the summed penalty; a tensor whenever at least one ``ReLU`` is
            visited, otherwise ``0``.
        """
        def _walk(loss, values, modules):
            for module in modules:
                if isinstance(module, torch.nn.Sequential):
                    # Iterating a Sequential yields its layers in order.
                    loss, values = _walk(loss, values, module)
                elif isinstance(module, torch.nn.ReLU):
                    values = module(values)
                    loss = loss + torch.mean(torch.abs(values))
                elif isinstance(module, torch.nn.Linear):
                    values = module(values)
                # Any other module type is deliberately left out of the sparse loss.
            return loss, values

        loss, _ = _walk(0, images, self.children())
        return loss

    def process_loss(self, train_loss, features, outputs) -> torch.Tensor:
        """Add the weighted L1 sparsity penalty to the reconstruction loss.

        :param train_loss: reconstruction loss for the batch (presumably a scalar tensor).
        :param features: network input; the sparsity penalty is computed from it.
        :param outputs: network output; unused here, kept for interface compatibility.
        :return: ``train_loss + regularization_parameter * get_sparse_loss(features)``.
        """
        l1_loss = self.get_sparse_loss(features)
        # Add sparsity penalty
        return train_loss + self.regularization_parameter * l1_loss