Implement several auto-encoder types and a corruption model, switch the loss graphs to a log scale, and add helper-function hooks to the training loop so encoder implementations can customize each step.
- Encoders: sparse, denoising, contractive and variational
- Noise: Gaussian
parent 62f9b873e9
commit fb9ce46bd8

16 changed files with 405 additions and 63 deletions
config.py
@@ -1,5 +1,7 @@
 MODEL_STORAGE_BASE_PATH = "/path/to/this/project/saved_models"
 DATASET_STORAGE_BASE_PATH = "/path/to/this/project/datasets"
+TRAIN_TEMP_DATA_BASE_PATH = "/path/to/this/project/train_temp"
+TEST_TEMP_DATA_BASE_PATH = "/path/to/this/project/test_temp"
 
 TEST_RUNS = [
     {
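For context, a TEST_RUNS entry wiring the new pieces together could look roughly like the sketch below. Only the *_kwargs keys are visible in main.py's diff; the keys holding the dotted class paths, and the dataset module path, are assumptions.

# Hypothetical TEST_RUNS entry; key names for the class paths and the dataset module path are assumptions.
TEST_RUNS = [
    {
        "dataset": "models.cifar10_dataset.Cifar10Dataset",
        "dataset_kwargs": {},
        "encoder": "models.denoising_encoder.DenoisingAutoEncoder",
        "encoder_kwargs": {
            "input_shape": 3 * 32 * 32,
            "input_corruption_model": "models.gaussian_corruption.GaussianCorruption",
        },
        "corruption": "models.gaussian_corruption.GaussianCorruption",
        "corruption_kwargs": {},
    },
]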
logging.conf
@@ -1,5 +1,5 @@
 [loggers]
-keys=root
+keys=root,matplotlib
 
 [handlers]
 keys=consoleHandler,fileHandler
@@ -11,6 +11,12 @@ keys=simpleFormatter
 level=DEBUG
 handlers=consoleHandler,fileHandler
 
+[logger_matplotlib]
+level=NOTSET
+handlers=
+propagate=0
+qualname=matplotlib
+
 [handler_fileHandler]
 class=FileHandler
 level=DEBUG
main.py (18 changed lines)
@@ -1,11 +1,10 @@
-import importlib
-
 import config
 import logging.config
 
 # Get logging as early as possible!
 logging.config.fileConfig("logging.conf")
 
+from utils import load_dotted_path
+
 from models.base_corruption import BaseCorruption
 from models.base_dataset import BaseDataset
@@ -13,13 +12,6 @@ from models.base_encoder import BaseEncoder
 from models.test_run import TestRun
 
 
-def load_dotted_path(path):
-    split_path = path.split(".")
-    modulename, classname = ".".join(split_path[:-1]), split_path[-1]
-    model = getattr(importlib.import_module(modulename), classname)
-    return model
-
-
 def run_tests():
     logger = logging.getLogger("main.run_tests")
     for test in config.TEST_RUNS:
@@ -41,9 +33,11 @@ def run_tests():
         logger.debug(f"Using corruption model '{corruption_model.__name__}'")
 
         # Create TestRun instance
-        test_run = TestRun(dataset=dataset_model(**test['dataset_kwargs']),
-                           encoder=encoder_model(**test['encoder_kwargs']),
-                           corruption=corruption_model(**test['corruption_kwargs']))
+        dataset = dataset_model(**test['dataset_kwargs'])
+        encoder = encoder_model(**test['encoder_kwargs'])
+        encoder.after_init()
+        corruption = corruption_model(**test['corruption_kwargs'])
+        test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
 
         # Run TestRun
         test_run.run(retrain=False)
models/base_corruption.py
@@ -15,7 +15,11 @@ class BaseCorruption:
         return f"{self.name}"
 
     @classmethod
-    def corrupt(cls, dataset: BaseDataset) -> BaseDataset:
+    def corrupt_image(cls, image):
+        raise NotImplementedError()
+
+    @classmethod
+    def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
         raise NotImplementedError()
 
 
@@ -26,5 +30,9 @@ class NoCorruption(BaseCorruption):
     name = "No corruption"
 
     @classmethod
-    def corrupt(cls, dataset: BaseDataset) -> BaseDataset:
+    def corrupt_image(cls, image):
+        return image
+
+    @classmethod
+    def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
         return dataset
models/base_dataset.py
@@ -78,6 +78,12 @@ class BaseDataset(Dataset):
         self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path)
         self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path)
 
+    def has_train(self):
+        return self._trainset is not None
+
+    def has_test(self):
+        return self._testset is not None
+
     def get_train(self) -> 'BaseDataset':
         if not self._trainset or not self._testset:
             self._subdivide(self.TRAIN_AMOUNT)
models/base_encoder.py
@@ -6,6 +6,7 @@ import torch
 
 from typing import Optional
 
+from torch.nn.modules.loss import _Loss
 from torchvision.utils import save_image
 
 from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH
@@ -37,7 +38,8 @@ class BaseEncoder(torch.nn.Module):
             torch.nn.ReLU(),
             torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2),
             torch.nn.ReLU(),
-            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape)
+            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape),
+            torch.nn.ReLU()
         )
 
         # Use GPU acceleration if available
@@ -50,7 +52,10 @@ class BaseEncoder(torch.nn.Module):
         self.loss_function = torch.nn.MSELoss()
 
     def after_init(self):
-        self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with {len(list(self.network.children()))} layers on {self.device.type}. Optimizer: {self.optimizer}, Loss function: {self.loss_function}")
+        self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with "
+                      f"{len(list(self.network.children())) if self.network else 'custom'} layers on "
+                      f"{self.device.type}. Optimizer: {self.optimizer.__class__.__name__}, "
+                      f"Loss function: {self.loss_function.__class__.__name__}")
 
     def forward(self, features):
         return self.network(features)
@@ -109,21 +114,33 @@ class BaseEncoder(torch.nn.Module):
 
         outputs = None
         for epoch in range(epochs):
-            self.log.debug(f"Start training epoch {epoch}...")
+            self.log.debug(f"Start training epoch {epoch + 1}...")
             loss = 0
-            for batch_features in train_loader:
-                # load batch features to the active device
-                batch_features = batch_features.to(self.device)
+            for i, batch_features in enumerate(train_loader):
+                # # load batch features to the active device
+                # batch_features = batch_features.to(self.device)
 
                 # reset the gradients back to zero
                 # PyTorch accumulates gradients on subsequent backward passes
                 self.optimizer.zero_grad()
 
+                # Modify features used in training the model (if necessary) and load to the active device
+                train_features = self.process_train_features(batch_features).to(self.device)
+
                 # compute reconstructions
-                outputs = self(batch_features)
+                outputs = self(train_features)
+
+                # Modify outputs used in the loss function (if necessary) and load to the active device
+                outputs_for_loss = self.process_outputs_for_loss_function(outputs).to(self.device)
+
+                # Modify features used for comparison in the loss function (if necessary) and load to the active device
+                compare_features = self.process_compare_features(batch_features).to(self.device)
+
                 # compute training reconstruction loss
-                train_loss = self.loss_function(outputs, batch_features)
+                train_loss = self.loss_function(outputs_for_loss, compare_features)
+
+                # Process loss if necessary (default implementation does nothing)
+                train_loss = self.process_loss(train_loss, compare_features, outputs)
 
                 # compute accumulated gradients
                 train_loss.backward()
@@ -134,6 +151,11 @@ class BaseEncoder(torch.nn.Module):
                 # add the mini-batch training loss to epoch loss
                 loss += train_loss.item()
 
+                # Print progress every 50 batches
+                if i % 50 == 0:
+                    self.log.debug(f"  progress: [{i * len(batch_features)}/{len(train_loader.dataset)} "
+                                   f"({(100 * i / len(train_loader)):.0f}%)]")
+
             # compute the epoch training loss
             loss = loss / len(train_loader)
 
@@ -143,7 +165,7 @@ class BaseEncoder(torch.nn.Module):
 
             # Every 5 epochs, save a test image
            if epoch % 5 == 0:
-                img = outputs.cpu().data
+                img = self.process_outputs_for_testing(outputs).cpu().data
                 img = img.view(img.size(0), 3, 32, 32)
                 save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
                                              f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png'))
@@ -163,10 +185,25 @@ class BaseEncoder(torch.nn.Module):
                                          f'{self.name}_{dataset.name}_test_input_{i}.png'))
             # load batch features to the active device
             batch = batch.to(self.device)
-            outputs = self(batch)
+            outputs = self.process_outputs_for_testing(self(batch))
             img = outputs.cpu().data
             img = img.view(outputs.size(0), 3, 32, 32)
             save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
                                          f'{self.name}_{dataset.name}_test_reconstruction_{i}.png'))
             i += 1
             break
+
+    def process_loss(self, train_loss, features, outputs) -> _Loss:
+        return train_loss
+
+    def process_train_features(self, features):
+        return features
+
+    def process_compare_features(self, features):
+        return features
+
+    def process_outputs_for_loss_function(self, outputs):
+        return outputs
+
+    def process_outputs_for_testing(self, outputs):
+        return outputs
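The new process_* hooks give subclasses a place to adapt each step of the training loop without re-implementing it. As a hypothetical illustration (not part of this commit), a subclass that clamps its training inputs and adds a small output penalty would only need to override two hooks:

import logging
import torch

from typing import Optional

from models.base_encoder import BaseEncoder


class ClippedAutoEncoder(BaseEncoder):
    # Hypothetical subclass, not in this commit; it only illustrates the hook API.
    name = "ClippedAutoEncoder"

    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
        self.log = logging.getLogger(self.__class__.__name__)
        super(ClippedAutoEncoder, self).__init__(name, input_shape)
        # Network, optimizer and loss function stay as defined in the base encoder.

    def process_train_features(self, features):
        # Runs before the forward pass: clamp inputs to [0, 1].
        return torch.clamp(features, 0.0, 1.0)

    def process_loss(self, train_loss, features, outputs):
        # Runs after the base loss: add a small, made-up L2 penalty on the reconstructions.
        return train_loss + 1e-4 * torch.mean(outputs ** 2)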
BasicAutoEncoder module (file path not shown)
@@ -1,8 +1,4 @@
-import json
 import logging
-import os
-
-import torch
 
 from typing import Optional
 
@@ -19,30 +15,4 @@ class BasicAutoEncoder(BaseEncoder):
         # Call superclass to initialize parameters.
         super(BasicAutoEncoder, self).__init__(name, input_shape)
 
-        # Override parameters to custom values for this encoder type
-
-        # TODO - What is the best way to decide how large the intermediate layers should be?
-        # - Proportionally from the input size to the given bottleneck size?
-        # - Take it from a paper
-        # - Pick it ourselves (e.g. halve it each layer, a fixed number of layers)?
-        self.network = torch.nn.Sequential(
-            torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2),
-            torch.nn.ReLU(),
-            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4),
-            torch.nn.ReLU(),
-            torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2),
-            torch.nn.ReLU(),
-            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape)
-        )
-
-        # Use GPU acceleration if available
-        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        # Adam optimizer with learning rate 1e-3
-        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
-
-        # Mean Squared Error loss function
-        # self.loss_function = torch.nn.MSELoss()
-
-        self.after_init()
+        # Network, optimizer and loss function are the same as defined in the base encoder.
Cifar10Dataset module (file path not shown)
@@ -3,8 +3,6 @@ import os
 from typing import Optional
 
 import numpy
-import torchvision
-from PIL import Image
 from torchvision import transforms
 
 from config import DATASET_STORAGE_BASE_PATH
@@ -12,6 +10,7 @@ from models.base_dataset import BaseDataset
 
 
 class Cifar10Dataset(BaseDataset):
+    name = "CIFAR-10"
 
     # transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
     #                                             torchvision.transforms.Normalize((0.5, ), (0.5, ))
models/contractive_encoder.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+import logging
+
+import torch
+
+from typing import Optional
+
+from torch.autograd import Variable
+
+from models.base_encoder import BaseEncoder
+
+
+class ContractiveAutoEncoder(BaseEncoder):
+    # Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
+    name = "ContractiveAutoEncoder"
+
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4):
+        self.log = logging.getLogger(self.__class__.__name__)
+
+        # Call superclass to initialize parameters.
+        super(ContractiveAutoEncoder, self).__init__(name, input_shape)
+
+        self.regularizer_weight = regularizer_weight
+
+        # CAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
+        self.network = None
+        self.encoder = torch.nn.Sequential(
+            torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2, bias=False),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False),
+            torch.nn.ReLU()
+        )
+        self.decoder = torch.nn.Sequential(
+            torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2, bias=False),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape, bias=False),
+            torch.nn.ReLU()
+        )
+
+        # Adam optimizer with learning rate 1e-3 (parameters changed so we need to re-declare it)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+        # Loss function is the same as defined in the base encoder.
+
+    # Because the network is split up now, and a CAE needs the intermediate representations,
+    # we need to override the forward method and return the intermediate state as well.
+    def forward(self, features):
+        encoder_out = self.encoder(features)
+        decoder_out = self.decoder(encoder_out)
+        return encoder_out, decoder_out
+
+    # Because the model returns the intermediate state as well as the output,
+    # we need to override some output processing functions
+    def process_outputs_for_loss_function(self, outputs):
+        return outputs[1]
+
+    def process_outputs_for_testing(self, outputs):
+        return outputs[1]
+
+    # Finally we define the loss process function, which adds the contractive loss.
+    def process_loss(self, train_loss, features, outputs):
+        """
+        Evaluates the CAE loss, which is the summation of the MSE and the weighted L2-norm of the Jacobian of the
+        hidden units with respect to the inputs.
+
+        Reference: http://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder
+
+        :param train_loss: The (MSE) loss as returned by the loss function of the model
+        :param features: The input features
+        :param outputs: The raw outputs as returned by the model (in this case includes the hidden encoder output)
+        """
+        hidden_output = outputs[0]
+        # Weights of the second Linear layer in the encoder (index 2)
+        weights = self.state_dict()['encoder.2.weight']
+
+        # Hadamard product
+        hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
+        dh = hidden_output * (1 - hidden_output)
+
+        # Sum through input dimension to improve efficiency (suggested in reference)
+        w_sum = torch.sum(Variable(weights) ** 2, dim=1)
+
+        # Unsqueeze to avoid issues with torch.mv
+        w_sum = w_sum.unsqueeze(1)
+
+        # Calculate contractive loss
+        contractive_loss = torch.sum(torch.mm(dh ** 2, w_sum), 0)
+
+        return train_loss + contractive_loss.mul_(self.regularizer_weight)
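For reference, the penalty assembled in process_loss is the closed form that the linked reference derives for a sigmoid hidden layer: the squared Frobenius norm of the Jacobian of the hidden output h with respect to the input, with W the weights of that layer,

\|J_h(x)\|_F^2 = \sum_i \big[h_i(1-h_i)\big]^2 \sum_j W_{ij}^2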
models/denoising_encoder.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+import logging
+
+import torch
+
+from typing import Optional
+
+from torch import Tensor
+
+from main import load_dotted_path
+from models.base_corruption import BaseCorruption, NoCorruption
+from models.base_encoder import BaseEncoder
+
+
+class DenoisingAutoEncoder(BaseEncoder):
+    # Based on https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb
+    name = "DenoisingAutoEncoder"
+
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0,
+                 input_corruption_model: BaseCorruption = NoCorruption):
+        self.log = logging.getLogger(self.__class__.__name__)
+
+        # Call superclass to initialize parameters.
+        super(DenoisingAutoEncoder, self).__init__(name, input_shape)
+
+        # Network, optimizer and loss function are the same as defined in the base encoder.
+
+        # Corruption used for data corruption during training
+        if isinstance(input_corruption_model, str):
+            self.input_corruption_model = load_dotted_path(input_corruption_model)
+        else:
+            self.input_corruption_model = input_corruption_model
+
+    # Need to corrupt features used for training, to add 'noise' to training data (comparison features are unchanged)
+    def process_train_features(self, features: Tensor) -> Tensor:
+        return torch.tensor([self.input_corruption_model.corrupt_image(x) for x in features], dtype=torch.float32)
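As a usage sketch (not from the commit), the corruption model can be passed either as a class or as a dotted-path string; the input_shape of 3 * 32 * 32 is an assumption matching the 3x32x32 images written out by the base encoder.

# Hypothetical usage; the dotted path is resolved with load_dotted_path.
encoder = DenoisingAutoEncoder(input_shape=3 * 32 * 32,
                               input_corruption_model="models.gaussian_corruption.GaussianCorruption")
encoder.after_init()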
models/gaussian_corruption.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from torch import Tensor
+
+from models.base_corruption import BaseCorruption
+from models.base_dataset import BaseDataset
+import numpy
+
+
+def add_noise(image):
+    image = image.astype(numpy.float32)
+    mean, variance = 0, 0.1
+    sigma = variance ** 0.5
+    noise = numpy.random.normal(mean, sigma, image.shape).reshape(image.shape)
+    return image + noise
+
+
+class GaussianCorruption(BaseCorruption):
+    """
+    Corruption model that adds Gaussian noise to the dataset.
+    """
+    name = "Gaussian"
+
+    @classmethod
+    def corrupt_image(cls, image: Tensor):
+        return add_noise(image.numpy())
+
+    @classmethod
+    def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
+        data = list(map(add_noise, dataset))
+        train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None
+        test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None
+        return dataset.__class__.get_new(
+            name=f"{dataset.name} Corrupted",
+            data=data,
+            source_path=dataset._source_path,
+            train_set=train_set,
+            test_set=test_set)
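A quick sanity check (illustrative, not part of the commit): corrupt_image converts the tensor to numpy and adds zero-mean noise with variance 0.1, so the per-pixel deviation should come out around sqrt(0.1) ≈ 0.32.

import torch

from models.gaussian_corruption import GaussianCorruption

clean = torch.rand(3, 32, 32)                      # fake CIFAR-10-sized image
noisy = GaussianCorruption.corrupt_image(clean)    # numpy array: image + N(0, 0.1) noise
print(noisy.shape, (noisy - clean.numpy()).std())  # std of the difference is roughly 0.32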
models/sparse_encoder.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+import logging
+import math
+
+import torch
+
+from typing import Optional
+
+from torch.nn.modules.loss import _Loss
+
+from models.base_encoder import BaseEncoder
+
+
+class SparseL1AutoEncoder(BaseEncoder):
+    # Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
+    name = "SparseL1AutoEncoder"
+
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularization_parameter: float = 0.001):
+        self.log = logging.getLogger(self.__class__.__name__)
+
+        # Call superclass to initialize parameters.
+        super(SparseL1AutoEncoder, self).__init__(name, input_shape)
+
+        # Override parameters to custom values for this encoder type
+
+        # The sparse encoder has larger intermediary layers, so increase them 1.5 times in the first layer
+        # and 2 times in the second layer (compared to the original input shape)
+        self.network = torch.nn.Sequential(
+            torch.nn.Linear(in_features=input_shape, out_features=math.floor(input_shape * 1.5)),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=math.floor(input_shape * 1.5), out_features=input_shape * 2),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape * 2, out_features=math.floor(input_shape * 1.5)),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=math.floor(input_shape * 1.5), out_features=input_shape),
+            torch.nn.ReLU()
+        )
+
+        # Adam optimizer with learning rate 1e-3 (parameters changed so we need to re-declare it)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+        # Loss function is the same as defined in the base encoder.
+
+        # Regularization parameter (lambda) for the L1 sparse loss function
+        self.regularization_parameter = regularization_parameter
+
+    def get_sparse_loss(self, images):
+        def get_sparse_loss_rec(loss, values, children):
+            for child in children:
+                if isinstance(child, torch.nn.Sequential):
+                    loss, values = get_sparse_loss_rec(loss, values, [x for x in child])
+                elif isinstance(child, torch.nn.ReLU):
+                    values = child(values)
+                    loss += torch.mean(torch.abs(values))
+                elif isinstance(child, torch.nn.Linear):
+                    values = child(values)
+                else:
+                    # Ignore unknown layers in sparse loss calculation
+                    pass
+            return loss, values
+
+        loss, values = get_sparse_loss_rec(loss=0, values=images, children=list(self.children()))
+        return loss
+
+    def process_loss(self, train_loss, features, outputs) -> _Loss:
+        l1_loss = self.get_sparse_loss(features)
+        # Add sparsity penalty
+        return train_loss + (self.regularization_parameter * l1_loss)
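The recursive walk in get_sparse_loss amounts to re-running the batch through the network and summing the mean absolute activation after every ReLU. A stripped-down equivalent (illustrative only, not part of the commit) is:

import torch


def l1_activation_penalty(network: torch.nn.Sequential, images: torch.Tensor) -> torch.Tensor:
    # Forward the batch layer by layer and accumulate the mean absolute activation after each ReLU.
    loss, values = torch.tensor(0.0), images
    for layer in network:
        values = layer(values)
        if isinstance(layer, torch.nn.ReLU):
            loss = loss + values.abs().mean()
    return loss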
models/test_run.py
@@ -4,7 +4,7 @@ import multiprocessing
 from models.base_corruption import BaseCorruption
 from models.base_dataset import BaseDataset
 from models.base_encoder import BaseEncoder
-from output_utils import save_train_loss_graph
+from utils import save_train_loss_graph
 
 
 class TestRun:
@@ -49,7 +49,7 @@ class TestRun:
 
             # Save train loss graph
             self.log.info("Saving loss graph...")
-            save_train_loss_graph(train_loss, self.dataset.name)
+            save_train_loss_graph(train_loss, f"{self.encoder.name}_{self.dataset.name}")
         else:
             self.log.info("Loading saved auto-encoder...")
             load_success = self.encoder.load_model(f"{self.encoder.name}_{self.dataset.name}")
models/variational_encoder.py (new file, 84 lines)
@@ -0,0 +1,84 @@
+import logging
+
+import torch
+
+from typing import Optional
+
+from models.base_encoder import BaseEncoder
+
+
+class VariationalAutoEncoder(BaseEncoder):
+    # Based on https://debuggercafe.com/getting-started-with-variational-autoencoder-using-pytorch/
+    # and https://github.com/pytorch/examples/blob/master/vae/main.py
+    name = "VariationalAutoEncoder"
+
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+        self.log = logging.getLogger(self.__class__.__name__)
+
+        # Call superclass to initialize parameters.
+        super(VariationalAutoEncoder, self).__init__(name, input_shape)
+
+        # VAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
+        # with no ReLU layer at the end of the encoder so we have access to the mu and variance.
+        # We also split the last layer of the encoder in two, so we can make two passes.
+        # One to determine the mu and one to determine the variance
+        self.network = None
+        self.encoder1 = torch.nn.Sequential(
+            torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2, bias=False),
+            torch.nn.ReLU()
+        )
+        self.encoder2_1 = torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False)
+        self.encoder2_2 = torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False)
+        self.decoder = torch.nn.Sequential(
+            torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2, bias=False),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape, bias=False),
+            torch.nn.ReLU()
+        )
+
+        # Adam optimizer with learning rate 1e-3 (parameters changed so we need to re-declare it)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+        # Loss function is the same as defined in the base encoder.
+
+    # Reparameterize takes a mu and variance, and returns a sample from the distribution N(mu(X), log_var(x))
+    def reparameterize(self, mu, log_var):
+        # z = μ(X) + Σ^(1/2)(X) ∗ ε
+        std_dev = torch.exp(0.5 * log_var)
+        epsilon = torch.randn_like(std_dev)
+        return mu + (std_dev * epsilon)
+
+    # Because the network is split up, and a VAE needs the intermediate representations, needs to modify the input
+    # to the decoder, and needs to return the parameters used for the sample, we need to override the forward method.
+    def forward(self, features):
+        encoder_out = self.encoder1(features)
+
+        # Use the two last layers of the encoder to determine the mu and log_var
+        mu = self.encoder2_1(encoder_out)
+        log_var = self.encoder2_2(encoder_out)
+
+        # Get a sample from the distribution with mu and log_var, for use in the decoder
+        sample = self.reparameterize(mu, log_var)
+
+        decoder_out = self.decoder(sample)
+        return decoder_out, mu, log_var
+
+    # Because the model returns the mu and log_var in addition to the output,
+    # we need to override some output processing functions
+    def process_outputs_for_loss_function(self, outputs):
+        return outputs[0]
+
+    def process_outputs_for_testing(self, outputs):
+        return outputs[0]
+
+    # After the loss function is executed, we modify the loss with KL divergence
+    def process_loss(self, train_loss, features, outputs):
+        # Loss = Reconstruction loss + KL divergence loss summed over all elements and batch
+        _, mu, log_var = outputs
+
+        # see Appendix B from VAE paper:
+        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
+        # https://arxiv.org/abs/1312.6114
+        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+        kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
+        return train_loss + kl_divergence
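For reference, the kl_divergence term is the closed-form KL divergence between the approximate posterior N(μ, σ²) and the standard normal prior from Appendix B of Kingma and Welling, with log_var = log σ²:

D_{KL}\left(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,I)\right) = -\frac{1}{2}\sum_{j}\left(1+\log\sigma_{j}^{2}-\mu_{j}^{2}-\sigma_{j}^{2}\right)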
requirements.txt
@@ -1,5 +1,5 @@
-torch==1.7.0+cu110
-torchvision==0.8.1+cu110
-torchaudio===0.7.0
+torch==1.7.1
+torchvision==0.8.2
+torchaudio===0.7.2
 tabulate
 matplotlib
utils.py
@@ -1,3 +1,4 @@
+import importlib
 import os
 from string import Template
 
@@ -8,6 +9,13 @@ import matplotlib.pyplot as plt
 from config import TRAIN_TEMP_DATA_BASE_PATH
 
 
+def load_dotted_path(path):
+    split_path = path.split(".")
+    modulename, classname = ".".join(split_path[:-1]), split_path[-1]
+    model = getattr(importlib.import_module(modulename), classname)
+    return model
+
+
 def training_header():
     """Prints a training header to the console"""
     print('| Epoch | Avg. loss | Train Acc. | Test Acc.  | Elapsed  |   ETA    |\n'
@@ -106,4 +114,5 @@ def save_train_loss_graph(train_loss, filename):
     plt.title('Train Loss')
     plt.xlabel('Epochs')
     plt.ylabel('Loss')
+    plt.yscale('log')
     plt.savefig(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.png'))
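For reference, load_dotted_path resolves a dotted string to the named attribute of the imported module, which is what lets TEST_RUNS reference models by string:

from utils import load_dotted_path

# Resolves the string to the class object defined in this commit.
encoder_cls = load_dotted_path("models.sparse_encoder.SparseL1AutoEncoder")
encoder = encoder_cls(input_shape=3 * 32 * 32)  # input_shape value here is an illustrative assumption
encoder.after_init()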