Implement types of auto-encoders and corruption, use log scale in loss graphs, lots of helper function hooks in training process to allow implementations
- Encoders: sparse, denoising, contractive and variational - Noise: gaussian
This commit is contained in:
parent
62f9b873e9
commit
fb9ce46bd8
|
@ -1,5 +1,7 @@
|
||||||
MODEL_STORAGE_BASE_PATH = "/path/to/this/project/saved_models"
|
MODEL_STORAGE_BASE_PATH = "/path/to/this/project/saved_models"
|
||||||
DATASET_STORAGE_BASE_PATH = "/path/to/this/project/datasets"
|
DATASET_STORAGE_BASE_PATH = "/path/to/this/project/datasets"
|
||||||
|
TRAIN_TEMP_DATA_BASE_PATH = "/path/to/this/project/train_temp"
|
||||||
|
TEST_TEMP_DATA_BASE_PATH = "/path/to/this/project/test_temp"
|
||||||
|
|
||||||
TEST_RUNS = [
|
TEST_RUNS = [
|
||||||
{
|
{
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
[loggers]
|
[loggers]
|
||||||
keys=root
|
keys=root,matplotlib
|
||||||
|
|
||||||
[handlers]
|
[handlers]
|
||||||
keys=consoleHandler,fileHandler
|
keys=consoleHandler,fileHandler
|
||||||
|
@ -11,6 +11,12 @@ keys=simpleFormatter
|
||||||
level=DEBUG
|
level=DEBUG
|
||||||
handlers=consoleHandler,fileHandler
|
handlers=consoleHandler,fileHandler
|
||||||
|
|
||||||
|
[logger_matplotlib]
|
||||||
|
level=NOTSET
|
||||||
|
handlers=
|
||||||
|
propagate=0
|
||||||
|
qualname=matplotlib
|
||||||
|
|
||||||
[handler_fileHandler]
|
[handler_fileHandler]
|
||||||
class=FileHandler
|
class=FileHandler
|
||||||
level=DEBUG
|
level=DEBUG
|
||||||
|
|
18
main.py
18
main.py
|
@ -1,11 +1,10 @@
|
||||||
import importlib
|
|
||||||
|
|
||||||
import config
|
import config
|
||||||
import logging.config
|
import logging.config
|
||||||
|
|
||||||
# Get logging as early as possible!
|
# Get logging as early as possible!
|
||||||
logging.config.fileConfig("logging.conf")
|
logging.config.fileConfig("logging.conf")
|
||||||
|
|
||||||
|
from utils import load_dotted_path
|
||||||
|
|
||||||
from models.base_corruption import BaseCorruption
|
from models.base_corruption import BaseCorruption
|
||||||
from models.base_dataset import BaseDataset
|
from models.base_dataset import BaseDataset
|
||||||
|
@ -13,13 +12,6 @@ from models.base_encoder import BaseEncoder
|
||||||
from models.test_run import TestRun
|
from models.test_run import TestRun
|
||||||
|
|
||||||
|
|
||||||
def load_dotted_path(path):
|
|
||||||
split_path = path.split(".")
|
|
||||||
modulename, classname = ".".join(split_path[:-1]), split_path[-1]
|
|
||||||
model = getattr(importlib.import_module(modulename), classname)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def run_tests():
|
def run_tests():
|
||||||
logger = logging.getLogger("main.run_tests")
|
logger = logging.getLogger("main.run_tests")
|
||||||
for test in config.TEST_RUNS:
|
for test in config.TEST_RUNS:
|
||||||
|
@ -41,9 +33,11 @@ def run_tests():
|
||||||
logger.debug(f"Using corruption model '{corruption_model.__name__}'")
|
logger.debug(f"Using corruption model '{corruption_model.__name__}'")
|
||||||
|
|
||||||
# Create TestRun instance
|
# Create TestRun instance
|
||||||
test_run = TestRun(dataset=dataset_model(**test['dataset_kwargs']),
|
dataset = dataset_model(**test['dataset_kwargs'])
|
||||||
encoder=encoder_model(**test['encoder_kwargs']),
|
encoder = encoder_model(**test['encoder_kwargs'])
|
||||||
corruption=corruption_model(**test['corruption_kwargs']))
|
encoder.after_init()
|
||||||
|
corruption = corruption_model(**test['corruption_kwargs'])
|
||||||
|
test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
|
||||||
|
|
||||||
# Run TestRun
|
# Run TestRun
|
||||||
test_run.run(retrain=False)
|
test_run.run(retrain=False)
|
||||||
|
|
|
@ -15,7 +15,11 @@ class BaseCorruption:
|
||||||
return f"{self.name}"
|
return f"{self.name}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def corrupt(cls, dataset: BaseDataset) -> BaseDataset:
|
def corrupt_image(cls, image):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,5 +30,9 @@ class NoCorruption(BaseCorruption):
|
||||||
name = "No corruption"
|
name = "No corruption"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def corrupt(cls, dataset: BaseDataset) -> BaseDataset:
|
def corrupt_image(cls, image):
|
||||||
|
return image
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
|
||||||
return dataset
|
return dataset
|
||||||
|
|
|
@ -78,6 +78,12 @@ class BaseDataset(Dataset):
|
||||||
self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path)
|
self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path)
|
||||||
self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path)
|
self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path)
|
||||||
|
|
||||||
|
def has_train(self):
|
||||||
|
return self._trainset is not None
|
||||||
|
|
||||||
|
def has_test(self):
|
||||||
|
return self._testset is not None
|
||||||
|
|
||||||
def get_train(self) -> 'BaseDataset':
|
def get_train(self) -> 'BaseDataset':
|
||||||
if not self._trainset or not self._testset:
|
if not self._trainset or not self._testset:
|
||||||
self._subdivide(self.TRAIN_AMOUNT)
|
self._subdivide(self.TRAIN_AMOUNT)
|
||||||
|
|
|
@ -6,6 +6,7 @@ import torch
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from torch.nn.modules.loss import _Loss
|
||||||
from torchvision.utils import save_image
|
from torchvision.utils import save_image
|
||||||
|
|
||||||
from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH
|
from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH
|
||||||
|
@ -37,7 +38,8 @@ class BaseEncoder(torch.nn.Module):
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2),
|
torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2),
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape)
|
torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape),
|
||||||
|
torch.nn.ReLU()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use GPU acceleration if available
|
# Use GPU acceleration if available
|
||||||
|
@ -50,7 +52,10 @@ class BaseEncoder(torch.nn.Module):
|
||||||
self.loss_function = torch.nn.MSELoss()
|
self.loss_function = torch.nn.MSELoss()
|
||||||
|
|
||||||
def after_init(self):
|
def after_init(self):
|
||||||
self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with {len(list(self.network.children()))} layers on {self.device.type}. Optimizer: {self.optimizer}, Loss function: {self.loss_function}")
|
self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with "
|
||||||
|
f"{len(list(self.network.children())) if self.network else 'custom'} layers on "
|
||||||
|
f"{self.device.type}. Optimizer: {self.optimizer.__class__.__name__}, "
|
||||||
|
f"Loss function: {self.loss_function.__class__.__name__}")
|
||||||
|
|
||||||
def forward(self, features):
|
def forward(self, features):
|
||||||
return self.network(features)
|
return self.network(features)
|
||||||
|
@ -109,21 +114,33 @@ class BaseEncoder(torch.nn.Module):
|
||||||
|
|
||||||
outputs = None
|
outputs = None
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
self.log.debug(f"Start training epoch {epoch}...")
|
self.log.debug(f"Start training epoch {epoch + 1}...")
|
||||||
loss = 0
|
loss = 0
|
||||||
for batch_features in train_loader:
|
for i, batch_features in enumerate(train_loader):
|
||||||
# load batch features to the active device
|
# # load batch features to the active device
|
||||||
batch_features = batch_features.to(self.device)
|
# batch_features = batch_features.to(self.device)
|
||||||
|
|
||||||
# reset the gradients back to zero
|
# reset the gradients back to zero
|
||||||
# PyTorch accumulates gradients on subsequent backward passes
|
# PyTorch accumulates gradients on subsequent backward passes
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Modify features used in training model (if necessary) and load to the active device
|
||||||
|
train_features = self.process_train_features(batch_features).to(self.device)
|
||||||
|
|
||||||
# compute reconstructions
|
# compute reconstructions
|
||||||
outputs = self(batch_features)
|
outputs = self(train_features)
|
||||||
|
|
||||||
|
# Modify outputs used in loss function (if necessary) and load to the active device
|
||||||
|
outputs_for_loss = self.process_outputs_for_loss_function(outputs).to(self.device)
|
||||||
|
|
||||||
|
# Modify features used in comparing in loss function (if necessary) and load to the active device
|
||||||
|
compare_features = self.process_compare_features(batch_features).to(self.device)
|
||||||
|
|
||||||
# compute training reconstruction loss
|
# compute training reconstruction loss
|
||||||
train_loss = self.loss_function(outputs, batch_features)
|
train_loss = self.loss_function(outputs_for_loss, compare_features)
|
||||||
|
|
||||||
|
# Process loss if necessary (default implementation does nothing)
|
||||||
|
train_loss = self.process_loss(train_loss, compare_features, outputs)
|
||||||
|
|
||||||
# compute accumulated gradients
|
# compute accumulated gradients
|
||||||
train_loss.backward()
|
train_loss.backward()
|
||||||
|
@ -134,6 +151,11 @@ class BaseEncoder(torch.nn.Module):
|
||||||
# add the mini-batch training loss to epoch loss
|
# add the mini-batch training loss to epoch loss
|
||||||
loss += train_loss.item()
|
loss += train_loss.item()
|
||||||
|
|
||||||
|
# Print progress every 50 batches
|
||||||
|
if i % 50 == 0:
|
||||||
|
self.log.debug(f" progress: [{i * len(batch_features)}/{len(train_loader.dataset)} "
|
||||||
|
f"({(100 * i / len(train_loader)):.0f}%)]")
|
||||||
|
|
||||||
# compute the epoch training loss
|
# compute the epoch training loss
|
||||||
loss = loss / len(train_loader)
|
loss = loss / len(train_loader)
|
||||||
|
|
||||||
|
@ -143,7 +165,7 @@ class BaseEncoder(torch.nn.Module):
|
||||||
|
|
||||||
# Every 5 epochs, save a test image
|
# Every 5 epochs, save a test image
|
||||||
if epoch % 5 == 0:
|
if epoch % 5 == 0:
|
||||||
img = outputs.cpu().data
|
img = self.process_outputs_for_testing(outputs).cpu().data
|
||||||
img = img.view(img.size(0), 3, 32, 32)
|
img = img.view(img.size(0), 3, 32, 32)
|
||||||
save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
|
save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
|
||||||
f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png'))
|
f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png'))
|
||||||
|
@ -163,10 +185,25 @@ class BaseEncoder(torch.nn.Module):
|
||||||
f'{self.name}_{dataset.name}_test_input_{i}.png'))
|
f'{self.name}_{dataset.name}_test_input_{i}.png'))
|
||||||
# load batch features to the active device
|
# load batch features to the active device
|
||||||
batch = batch.to(self.device)
|
batch = batch.to(self.device)
|
||||||
outputs = self(batch)
|
outputs = self.process_outputs_for_testing(self(batch))
|
||||||
img = outputs.cpu().data
|
img = outputs.cpu().data
|
||||||
img = img.view(outputs.size(0), 3, 32, 32)
|
img = img.view(outputs.size(0), 3, 32, 32)
|
||||||
save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
|
save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
|
||||||
f'{self.name}_{dataset.name}_test_reconstruction_{i}.png'))
|
f'{self.name}_{dataset.name}_test_reconstruction_{i}.png'))
|
||||||
i += 1
|
i += 1
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def process_loss(self, train_loss, features, outputs) -> _Loss:
|
||||||
|
return train_loss
|
||||||
|
|
||||||
|
def process_train_features(self, features):
|
||||||
|
return features
|
||||||
|
|
||||||
|
def process_compare_features(self, features):
|
||||||
|
return features
|
||||||
|
|
||||||
|
def process_outputs_for_loss_function(self, outputs):
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def process_outputs_for_testing(self, outputs):
|
||||||
|
return outputs
|
||||||
|
|
|
@ -1,8 +1,4 @@
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -19,30 +15,4 @@ class BasicAutoEncoder(BaseEncoder):
|
||||||
# Call superclass to initialize parameters.
|
# Call superclass to initialize parameters.
|
||||||
super(BasicAutoEncoder, self).__init__(name, input_shape)
|
super(BasicAutoEncoder, self).__init__(name, input_shape)
|
||||||
|
|
||||||
# Override parameters to custom values for this encoder type
|
# Network, optimizer and loss function are the same as defined in the base encoder.
|
||||||
|
|
||||||
# TODO - Hoe kan ik het beste bepalen hoe groot de intermediate layers moeten zijn?
|
|
||||||
# - Proportioneel van input grootte naar opgegeven bottleneck grootte?
|
|
||||||
# - Uit een paper plukken
|
|
||||||
# - Zelf kiezen (e.g. helft elke keer, fixed aantal layers)?
|
|
||||||
self.network = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use GPU acceleration if available
|
|
||||||
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
# Adam optimizer with learning rate 1e-3
|
|
||||||
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
|
||||||
|
|
||||||
# Mean Squared Error loss function
|
|
||||||
# self.loss_function = torch.nn.MSELoss()
|
|
||||||
|
|
||||||
self.after_init()
|
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,6 @@ import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import torchvision
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
from config import DATASET_STORAGE_BASE_PATH
|
from config import DATASET_STORAGE_BASE_PATH
|
||||||
|
@ -12,6 +10,7 @@ from models.base_dataset import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
class Cifar10Dataset(BaseDataset):
|
class Cifar10Dataset(BaseDataset):
|
||||||
|
name = "CIFAR-10"
|
||||||
|
|
||||||
# transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
|
# transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
|
||||||
# torchvision.transforms.Normalize((0.5, ), (0.5, ))
|
# torchvision.transforms.Normalize((0.5, ), (0.5, ))
|
||||||
|
|
88
models/contractive_encoder.py
Normal file
88
models/contractive_encoder.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
from models.base_encoder import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class ContractiveAutoEncoder(BaseEncoder):
|
||||||
|
# Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
|
||||||
|
name = "ContractiveAutoEncoder"
|
||||||
|
|
||||||
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4):
|
||||||
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
|
# Call superclass to initialize parameters.
|
||||||
|
super(ContractiveAutoEncoder, self).__init__(name, input_shape)
|
||||||
|
|
||||||
|
self.regularizer_weight = regularizer_weight
|
||||||
|
|
||||||
|
# CAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
|
||||||
|
self.network = None
|
||||||
|
self.encoder = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2, bias=False),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
self.decoder = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2, bias=False),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape, bias=False),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Adam optimizer with learning rate 1e-3 (parameters changed so we need to re-declare it)
|
||||||
|
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
# Loss function is the same as defined in the base encoder.
|
||||||
|
|
||||||
|
# Because network is split up now, and a CAE needs the intermediate representations,
|
||||||
|
# we need to override the forward method and return the intermediate state as well.
|
||||||
|
def forward(self, features):
|
||||||
|
encoder_out = self.encoder(features)
|
||||||
|
decoder_out = self.decoder(encoder_out)
|
||||||
|
return encoder_out, decoder_out
|
||||||
|
|
||||||
|
# Because the model returns the intermediate state as well as the output,
|
||||||
|
# we need to override some output processing functions
|
||||||
|
def process_outputs_for_loss_function(self, outputs):
|
||||||
|
return outputs[1]
|
||||||
|
|
||||||
|
def process_outputs_for_testing(self, outputs):
|
||||||
|
return outputs[1]
|
||||||
|
|
||||||
|
# Finally we define the loss process function, which adds the contractive loss.
|
||||||
|
def process_loss(self, train_loss, features, outputs):
|
||||||
|
"""
|
||||||
|
Evaluates the CAE loss, which is the summation of the MSE and the weighted L2-norm of the Jacobian of the
|
||||||
|
hidden units with respect to the inputs.
|
||||||
|
|
||||||
|
Reference: http://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder
|
||||||
|
|
||||||
|
:param train_loss: The (MSE) loss as returned by the loss function of the model
|
||||||
|
:param features: The input features
|
||||||
|
:param outputs: The raw outputs as returned by the model (in this case includes the hidden encoder output)
|
||||||
|
"""
|
||||||
|
hidden_output = outputs[0]
|
||||||
|
# Weights of the second Linear layer in the encoder (index 2)
|
||||||
|
weights = self.state_dict()['encoder.2.weight']
|
||||||
|
|
||||||
|
# Hadamard product
|
||||||
|
hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
|
||||||
|
dh = hidden_output * (1 - hidden_output)
|
||||||
|
|
||||||
|
# Sum through input dimension to improve efficiency (suggested in reference)
|
||||||
|
w_sum = torch.sum(Variable(weights) ** 2, dim=1)
|
||||||
|
|
||||||
|
# Unsqueeze to avoid issues with torch.mv
|
||||||
|
w_sum = w_sum.unsqueeze(1)
|
||||||
|
|
||||||
|
# Calculate contractive loss
|
||||||
|
contractive_loss = torch.sum(torch.mm(dh ** 2, w_sum), 0)
|
||||||
|
|
||||||
|
return train_loss + contractive_loss.mul_(self.regularizer_weight)
|
35
models/denoising_encoder.py
Normal file
35
models/denoising_encoder.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from main import load_dotted_path
|
||||||
|
from models.base_corruption import BaseCorruption, NoCorruption
|
||||||
|
from models.base_encoder import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class DenoisingAutoEncoder(BaseEncoder):
|
||||||
|
# Based on https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb
|
||||||
|
name = "DenoisingAutoEncoder"
|
||||||
|
|
||||||
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0,
|
||||||
|
input_corruption_model: BaseCorruption = NoCorruption):
|
||||||
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
|
# Call superclass to initialize parameters.
|
||||||
|
super(DenoisingAutoEncoder, self).__init__(name, input_shape)
|
||||||
|
|
||||||
|
# Network, optimizer and loss function are the same as defined in the base encoder.
|
||||||
|
|
||||||
|
# Corruption used for data corruption during training
|
||||||
|
if isinstance(input_corruption_model, str):
|
||||||
|
self.input_corruption_model = load_dotted_path(input_corruption_model)
|
||||||
|
else:
|
||||||
|
self.input_corruption_model = input_corruption_model
|
||||||
|
|
||||||
|
# Need to corrupt features used for training, to add 'noise' to training data (comparison features are unchanged)
|
||||||
|
def process_train_features(self, features: Tensor) -> Tensor:
|
||||||
|
return torch.tensor([self.input_corruption_model.corrupt_image(x) for x in features], dtype=torch.float32)
|
36
models/gaussian_corruption.py
Normal file
36
models/gaussian_corruption.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from models.base_corruption import BaseCorruption
|
||||||
|
from models.base_dataset import BaseDataset
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
def add_noise(image):
|
||||||
|
image = image.astype(numpy.float32)
|
||||||
|
mean, variance = 0, 0.1
|
||||||
|
sigma = variance ** 0.5
|
||||||
|
noise = numpy.random.normal(mean, sigma, image.shape).reshape(image.shape)
|
||||||
|
return image + noise
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianCorruption(BaseCorruption):
|
||||||
|
"""
|
||||||
|
Corruption model that adds Gaussian noise to the dataset.
|
||||||
|
"""
|
||||||
|
name = "Gaussian"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def corrupt_image(cls, image: Tensor):
|
||||||
|
return add_noise(image.numpy())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
|
||||||
|
data = list(map(add_noise, dataset))
|
||||||
|
train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None
|
||||||
|
test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None
|
||||||
|
return dataset.__class__.get_new(
|
||||||
|
name=f"{dataset.name} Corrupted",
|
||||||
|
data=data,
|
||||||
|
source_path=dataset._source_path,
|
||||||
|
train_set=train_set,
|
||||||
|
test_set=test_set)
|
68
models/sparse_encoder.py
Normal file
68
models/sparse_encoder.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from torch.nn.modules.loss import _Loss
|
||||||
|
|
||||||
|
from models.base_encoder import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class SparseL1AutoEncoder(BaseEncoder):
|
||||||
|
# Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
|
||||||
|
name = "SparseL1AutoEncoder"
|
||||||
|
|
||||||
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularization_parameter: float = 0.001):
|
||||||
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
|
# Call superclass to initialize parameters.
|
||||||
|
super(SparseL1AutoEncoder, self).__init__(name, input_shape)
|
||||||
|
|
||||||
|
# Override parameters to custom values for this encoder type
|
||||||
|
|
||||||
|
# Sparse encoder has larger intermediary layers, so let's increase them 1.5 times in the first layer,
|
||||||
|
# and 2 times in the second layer (compared to the original input shape
|
||||||
|
self.network = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(in_features=input_shape, out_features=math.floor(input_shape * 1.5)),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(in_features=math.floor(input_shape * 1.5), out_features=input_shape * 2),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(in_features=input_shape * 2, out_features=math.floor(input_shape * 1.5)),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(in_features=math.floor(input_shape * 1.5), out_features=input_shape),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Adam optimizer with learning rate 1e-3 (parameters changed so we need to re-declare it)
|
||||||
|
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
# Loss function is the same as defined in the base encoder.
|
||||||
|
|
||||||
|
# Regularization parameter (lambda) for the L1 sparse loss function
|
||||||
|
self.regularization_parameter = regularization_parameter
|
||||||
|
|
||||||
|
def get_sparse_loss(self, images):
|
||||||
|
def get_sparse_loss_rec(loss, values, children):
|
||||||
|
for child in children:
|
||||||
|
if isinstance(child, torch.nn.Sequential):
|
||||||
|
loss, values = get_sparse_loss_rec(loss, values, [x for x in child])
|
||||||
|
elif isinstance(child, torch.nn.ReLU):
|
||||||
|
values = child(values)
|
||||||
|
loss += torch.mean(torch.abs(values))
|
||||||
|
elif isinstance(child, torch.nn.Linear):
|
||||||
|
values = child(values)
|
||||||
|
else:
|
||||||
|
# Ignore unknown layers in sparse loss calculation
|
||||||
|
pass
|
||||||
|
return loss, values
|
||||||
|
|
||||||
|
loss, values = get_sparse_loss_rec(loss=0, values=images, children=list(self.children()))
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def process_loss(self, train_loss, features, outputs) -> _Loss:
|
||||||
|
l1_loss = self.get_sparse_loss(features)
|
||||||
|
# Add sparsity penalty
|
||||||
|
return train_loss + (self.regularization_parameter * l1_loss)
|
|
@ -4,7 +4,7 @@ import multiprocessing
|
||||||
from models.base_corruption import BaseCorruption
|
from models.base_corruption import BaseCorruption
|
||||||
from models.base_dataset import BaseDataset
|
from models.base_dataset import BaseDataset
|
||||||
from models.base_encoder import BaseEncoder
|
from models.base_encoder import BaseEncoder
|
||||||
from output_utils import save_train_loss_graph
|
from utils import save_train_loss_graph
|
||||||
|
|
||||||
|
|
||||||
class TestRun:
|
class TestRun:
|
||||||
|
@ -49,7 +49,7 @@ class TestRun:
|
||||||
|
|
||||||
# Save train loss graph
|
# Save train loss graph
|
||||||
self.log.info("Saving loss graph...")
|
self.log.info("Saving loss graph...")
|
||||||
save_train_loss_graph(train_loss, self.dataset.name)
|
save_train_loss_graph(train_loss, f"{self.encoder.name}_{self.dataset.name}")
|
||||||
else:
|
else:
|
||||||
self.log.info("Loading saved auto-encoder...")
|
self.log.info("Loading saved auto-encoder...")
|
||||||
load_success = self.encoder.load_model(f"{self.encoder.name}_{self.dataset.name}")
|
load_success = self.encoder.load_model(f"{self.encoder.name}_{self.dataset.name}")
|
||||||
|
|
84
models/variational_encoder.py
Normal file
84
models/variational_encoder.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from models.base_encoder import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class VariationalAutoEncoder(BaseEncoder):
|
||||||
|
# Based on https://debuggercafe.com/getting-started-with-variational-autoencoder-using-pytorch/
|
||||||
|
# and https://github.com/pytorch/examples/blob/master/vae/main.py
|
||||||
|
name = "VariationalAutoEncoder"
|
||||||
|
|
||||||
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0):
|
||||||
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
|
# Call superclass to initialize parameters.
|
||||||
|
super(VariationalAutoEncoder, self).__init__(name, input_shape)
|
||||||
|
|
||||||
|
# VAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
|
||||||
|
# with no ReLU layer at the end of the encoder so we have access to the mu and variance.
|
||||||
|
# We also split the last layer of the encoder in two, so we can make two passes.
|
||||||
|
# One to determine the mu and one to determine the variance
|
||||||
|
self.network = None
|
||||||
|
self.encoder1 = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2, bias=False),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
self.encoder2_1 = torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False)
|
||||||
|
self.encoder2_2 = torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False)
|
||||||
|
self.decoder = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2, bias=False),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape, bias=False),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Adam optimizer with learning rate 1e-3 (parameters changed so we need to re-declare it)
|
||||||
|
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
# Loss function is the same as defined in the base encoder.
|
||||||
|
|
||||||
|
# Reparameterize takes a mu and variance, and returns a sample from the distribution N(mu(X), log_var(x))
|
||||||
|
def reparameterize(self, mu, log_var):
|
||||||
|
# z = μ(X) + Σ1/2(X)∗e
|
||||||
|
std_dev = torch.exp(0.5 * log_var)
|
||||||
|
epsilon = torch.randn_like(std_dev)
|
||||||
|
return mu + (std_dev * epsilon)
|
||||||
|
|
||||||
|
# Because network is split up, and a VAE needs the intermediate representations, needs to modify the input to the
|
||||||
|
# decoder, and needs to return the parameters used for the sample, we need to override the forward method.
|
||||||
|
def forward(self, features):
|
||||||
|
encoder_out = self.encoder1(features)
|
||||||
|
|
||||||
|
# Use the two last layers of the encoder to determine the mu and log_var
|
||||||
|
mu = self.encoder2_1(encoder_out)
|
||||||
|
log_var = self.encoder2_2(encoder_out)
|
||||||
|
|
||||||
|
# Get a sample from the distribution with mu and log_var, for use in the decoder
|
||||||
|
sample = self.reparameterize(mu, log_var)
|
||||||
|
|
||||||
|
decoder_out = self.decoder(sample)
|
||||||
|
return decoder_out, mu, log_var
|
||||||
|
|
||||||
|
# Because the model returns the mu and log_var in addition to the output,
|
||||||
|
# we need to override some output processing functions
|
||||||
|
def process_outputs_for_loss_function(self, outputs):
|
||||||
|
return outputs[0]
|
||||||
|
|
||||||
|
def process_outputs_for_testing(self, outputs):
|
||||||
|
return outputs[0]
|
||||||
|
|
||||||
|
# After the loss function is executed, we modify the loss with KL divergence
|
||||||
|
def process_loss(self, train_loss, features, outputs):
|
||||||
|
# Loss = Reconstruction loss + KL divergence loss summed over all elements and batch
|
||||||
|
_, mu, log_var = outputs
|
||||||
|
|
||||||
|
# see Appendix B from VAE paper:
|
||||||
|
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||||
|
# https://arxiv.org/abs/1312.6114
|
||||||
|
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||||
|
kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
|
||||||
|
return train_loss + kl_divergence
|
|
@ -1,5 +1,5 @@
|
||||||
torch==1.7.0+cu110
|
torch==1.7.1
|
||||||
torchvision==0.8.1+cu110
|
torchvision==0.8.2
|
||||||
torchaudio===0.7.0
|
torchaudio===0.7.2
|
||||||
tabulate
|
tabulate
|
||||||
matplotlib
|
matplotlib
|
|
@ -1,3 +1,4 @@
|
||||||
|
import importlib
|
||||||
import os
|
import os
|
||||||
from string import Template
|
from string import Template
|
||||||
|
|
||||||
|
@ -8,6 +9,13 @@ import matplotlib.pyplot as plt
|
||||||
from config import TRAIN_TEMP_DATA_BASE_PATH
|
from config import TRAIN_TEMP_DATA_BASE_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def load_dotted_path(path):
|
||||||
|
split_path = path.split(".")
|
||||||
|
modulename, classname = ".".join(split_path[:-1]), split_path[-1]
|
||||||
|
model = getattr(importlib.import_module(modulename), classname)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def training_header():
|
def training_header():
|
||||||
"""Prints a training header to the console"""
|
"""Prints a training header to the console"""
|
||||||
print('| Epoch | Avg. loss | Train Acc. | Test Acc. | Elapsed | ETA |\n'
|
print('| Epoch | Avg. loss | Train Acc. | Test Acc. | Elapsed | ETA |\n'
|
||||||
|
@ -106,4 +114,5 @@ def save_train_loss_graph(train_loss, filename):
|
||||||
plt.title('Train Loss')
|
plt.title('Train Loss')
|
||||||
plt.xlabel('Epochs')
|
plt.xlabel('Epochs')
|
||||||
plt.ylabel('Loss')
|
plt.ylabel('Loss')
|
||||||
|
plt.yscale('log')
|
||||||
plt.savefig(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.png'))
|
plt.savefig(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.png'))
|
Loading…
Reference in a new issue