import logging
from typing import Optional

import torch

from models.base_encoder import BaseEncoder


class ContractiveAutoEncoder(BaseEncoder):
    # Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
    name = "ContractiveAutoEncoder"

    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4):
        self.log = logging.getLogger(self.__class__.__name__)

        # Call the superclass to initialize parameters.
        super(ContractiveAutoEncoder, self).__init__(name, input_shape)
        self.regularizer_weight = regularizer_weight

        # A CAE needs the intermediate output of the encoder stage, so split the network into an encoder and a decoder.
        self.network = None
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False),
            torch.nn.ReLU()
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape, bias=False),
            torch.nn.ReLU()
        )

        # Adam optimizer with learning rate 1e-3 (the parameters changed, so we need to re-declare it).
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        # The loss function is the same as the one defined in the base encoder.

    # Because the network is now split up, and a CAE needs the intermediate representation,
    # we override the forward method and return the intermediate state as well.
    def forward(self, features):
        encoder_out = self.encoder(features)
        decoder_out = self.decoder(encoder_out)
        return encoder_out, decoder_out

    # Because the model returns the intermediate state as well as the output,
    # we need to override some output-processing functions.
    def process_outputs_for_loss_function(self, outputs):
        return outputs[1]

    def process_outputs_for_testing(self, outputs):
        return outputs[1]

    # Finally, we define the loss-processing function, which adds the contractive loss.
    def process_loss(self, train_loss, features, outputs):
        """
        Evaluates the CAE loss, which is the sum of the MSE and the weighted squared
        Frobenius norm of the Jacobian of the hidden units with respect to the inputs.

        Reference: http://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder

        :param train_loss: The (MSE) loss as returned by the loss function of the model
        :param features: The input features
        :param outputs: The raw outputs as returned by the model (in this case they include the hidden encoder output)
        """
        hidden_output = outputs[0]
        # Weights of the second Linear layer in the encoder (index 2 of the Sequential).
        # state_dict() returns detached tensors, so gradients flow only through dh.
        weights = self.state_dict()['encoder.2.weight']

        # Drop the singleton middle dimension: (batch, 1, hidden) -> (batch, hidden).
        hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
        # Hadamard product h * (1 - h), the derivative of a sigmoid activation, kept as in
        # the reference implementation (note: the encoder here uses ReLU, whose derivative is h > 0).
        dh = hidden_output * (1 - hidden_output)

        # Sum over the input dimension to improve efficiency (suggested in the reference).
        w_sum = torch.sum(weights ** 2, dim=1)
        # Unsqueeze to a column vector so it can be used with torch.mm below.
        w_sum = w_sum.unsqueeze(1)

        # Calculate the contractive loss.
        contractive_loss = torch.sum(torch.mm(dh ** 2, w_sum), 0)
        return train_loss + contractive_loss.mul_(self.regularizer_weight)
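

# The smoke test below is an illustrative sketch only: it assumes that BaseEncoder
# provides the standard torch.nn.Module plumbing (so the model is callable) and that
# inputs arrive shaped (batch, 1, features), matching the reshape in process_loss.
# The direct mse_loss call stands in for the base encoder's loss function, which is
# not shown in this file.
if __name__ == "__main__":
    model = ContractiveAutoEncoder(input_shape=128)

    # Fake batch: 32 samples of 128 features each, with a singleton middle dimension.
    batch = torch.rand(32, 1, 128)

    # One training step: forward pass, reconstruction loss, contractive penalty, update.
    encoder_out, decoder_out = model(batch)
    mse = torch.nn.functional.mse_loss(decoder_out, batch)
    loss = model.process_loss(mse, batch, (encoder_out, decoder_out))

    model.optimizer.zero_grad()
    loss.backward()
    model.optimizer.step()

    print(f"total loss after one step: {loss.item():.6f}")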