RP_AutoEncoderComparison/models/contractive_encoder.py

import logging
from typing import Optional

import torch

from models.base_encoder import BaseEncoder


class ContractiveAutoEncoder(BaseEncoder):
    # Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
    name = "ContractiveAutoEncoder"

    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4):
        self.log = logging.getLogger(self.__class__.__name__)
        # Call the superclass to initialize the shared parameters.
        super().__init__(name, input_shape)
        self.regularizer_weight = regularizer_weight
        # The CAE needs the intermediate output of the encoder stage, so the network is split into
        # separate encoder and decoder stages instead of the single `network` used by the base encoder.
        self.network = None
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False),
            torch.nn.ReLU()
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape, bias=False),
            torch.nn.ReLU()
        )
        # Adam optimizer with learning rate 1e-3 (the set of parameters has changed, so the optimizer
        # must be re-created here).
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    # The loss function itself is the same as the one defined in the base encoder.

    # Because the network is now split up, and a CAE needs the intermediate representation,
    # the forward method is overridden to return the hidden encoder output as well.
    def forward(self, features):
        encoder_out = self.encoder(features)
        decoder_out = self.decoder(encoder_out)
        return encoder_out, decoder_out

    # Because the model returns the intermediate state as well as the reconstruction,
    # the output processing functions are overridden to pick out the reconstruction.
    def process_outputs_for_loss_function(self, outputs):
        return outputs[1]

    def process_outputs_for_testing(self, outputs):
        return outputs[1]
    # Finally, process_loss is overridden to add the contractive penalty to the reconstruction loss.
    def process_loss(self, train_loss, features, outputs):
        """
        Evaluates the CAE loss, which is the sum of the MSE and the weighted squared Frobenius norm of the
        Jacobian of the hidden units with respect to the inputs:
            L = MSE(x, x_hat) + lambda * ||J_h(x)||_F^2
        Reference: http://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder

        :param train_loss: The (MSE) loss as returned by the loss function of the model
        :param features: The input features
        :param outputs: The raw outputs as returned by the model (in this case including the hidden encoder output)
        :return: The total loss (reconstruction loss plus weighted contractive penalty)
        """
        hidden_output = outputs[0]
        # Weights of the second Linear layer in the encoder (index 2 of the Sequential)
        weights = self.state_dict()['encoder.2.weight']
        # Drop the singleton middle dimension: the hidden output is expected to be (batch, 1, hidden_dim)
        hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
        # Hadamard product h * (1 - h), the derivative of the hidden units as used in the reference
        # implementation (note that this is the sigmoid derivative)
        dh = hidden_output * (1 - hidden_output)
        # Sum over the input dimension first to improve efficiency (suggested in the reference)
        w_sum = torch.sum(weights ** 2, dim=1)
        # Unsqueeze to a column vector so it can be used with torch.mm
        w_sum = w_sum.unsqueeze(1)
        # Calculate the contractive loss and add its weighted value to the reconstruction loss
        contractive_loss = torch.sum(torch.mm(dh ** 2, w_sum), 0)
        return train_loss + contractive_loss.mul_(self.regularizer_weight)
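

# A minimal usage sketch, assuming BaseEncoder behaves like a torch.nn.Module, that the reconstruction
# loss is a plain MSE, and that inputs are shaped (batch, 1, input_shape) as process_loss expects.
# The repository's training code would normally drive these calls; this only shows how the overridden
# hooks fit together in a single manual training step.
if __name__ == "__main__":
    input_shape = 64
    model = ContractiveAutoEncoder(input_shape=input_shape)
    # Fake batch with the (batch, 1, features) layout assumed by process_loss
    features = torch.rand(8, 1, input_shape)

    outputs = model(features)                                          # (hidden, reconstruction)
    reconstruction = model.process_outputs_for_loss_function(outputs)  # pick out the reconstruction
    mse = torch.nn.functional.mse_loss(reconstruction, features)
    loss = model.process_loss(mse, features, outputs)                  # add the contractive penalty

    model.optimizer.zero_grad()
    loss.backward()
    model.optimizer.step()
    print(f"total loss: {loss.item():.6f}")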