RP_AutoEncoderComparison/models/sparse_encoder.py

import logging
import math
from typing import Optional

import torch
from torch.nn.modules.loss import _Loss

from models.base_encoder import BaseEncoder


class SparseL1AutoEncoder(BaseEncoder):
    # Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
    name = "SparseL1AutoEncoder"

    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularization_parameter: float = 0.001):
        self.log = logging.getLogger(self.__class__.__name__)
        # Call the superclass to initialize parameters.
        super(SparseL1AutoEncoder, self).__init__(name, input_shape)
        # Override parameters with custom values for this encoder type.
        # The sparse encoder has larger intermediate layers: 1.5 times the input shape
        # in the first hidden layer and 2 times in the second (compared to the original input shape).
        self.network = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_shape, out_features=math.floor(input_shape * 1.5)),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=math.floor(input_shape * 1.5), out_features=input_shape * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=input_shape * 2, out_features=math.floor(input_shape * 1.5)),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=math.floor(input_shape * 1.5), out_features=input_shape),
            torch.nn.ReLU()
        )
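        # For illustration: with 28x28 MNIST images flattened to input_shape=784,
        # the layer widths run 784 -> 1176 -> 1568 -> 1176 -> 784.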
        # Adam optimizer with learning rate 1e-3 (the parameters changed, so the optimizer must be re-declared).
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # The loss function stays the same as defined in the base encoder.
        # Regularization parameter (lambda) for the L1 sparsity penalty.
        self.regularization_parameter = regularization_parameter

    def get_sparse_loss(self, images):
        def get_sparse_loss_rec(loss, values, children):
            for child in children:
                if isinstance(child, torch.nn.Sequential):
                    # Recurse into containers so nested layers are visited in order.
                    loss, values = get_sparse_loss_rec(loss, values, list(child))
                elif isinstance(child, torch.nn.ReLU):
                    values = child(values)
                    loss += torch.mean(torch.abs(values))
                elif isinstance(child, torch.nn.Linear):
                    values = child(values)
                else:
                    # Ignore unknown layer types in the sparse loss calculation.
                    pass
            return loss, values

        loss, values = get_sparse_loss_rec(loss=0, values=images, children=list(self.children()))
        return loss
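
    # Note: the recursion above replays the forward pass layer by layer and
    # accumulates mean(|activations|) after every ReLU (including the one on
    # the output layer), so the penalty is L1 = sum_l mean(|a_l|) over all
    # activation layers l; process_loss below scales it by lambda.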

    def process_loss(self, train_loss, features, outputs) -> _Loss:
        l1_loss = self.get_sparse_loss(features)
        # Add the sparsity penalty to the base reconstruction loss.
        return train_loss + (self.regularization_parameter * l1_loss)
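

# A minimal usage sketch (hypothetical: the real training loop lives in
# BaseEncoder elsewhere in this repository, and torch.nn.functional.mse_loss
# below merely stands in for whatever criterion the base encoder defines):
if __name__ == "__main__":
    model = SparseL1AutoEncoder(input_shape=784)     # e.g. flattened 28x28 images
    batch = torch.rand(32, 784)
    outputs = model.network(batch)                   # reconstruction
    base_loss = torch.nn.functional.mse_loss(outputs, batch)
    total_loss = model.process_loss(base_loss, batch, outputs)
    model.optimizer.zero_grad()
    total_loss.backward()
    model.optimizer.step()
    print(f"Total loss after one step: {total_loss.item():.6f}")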