diff --git a/config.example.py b/config.example.py index b003cd9..ef885de 100644 --- a/config.example.py +++ b/config.example.py @@ -1,5 +1,7 @@ MODEL_STORAGE_BASE_PATH = "/path/to/this/project/saved_models" DATASET_STORAGE_BASE_PATH = "/path/to/this/project/datasets" +TRAIN_TEMP_DATA_BASE_PATH = "/path/to/this/project/train_temp" +TEST_TEMP_DATA_BASE_PATH = "/path/to/this/project/test_temp" TEST_RUNS = [ { diff --git a/logging.conf b/logging.conf index 79cb15f..56846c9 100644 --- a/logging.conf +++ b/logging.conf @@ -1,5 +1,5 @@ [loggers] -keys=root +keys=root,matplotlib [handlers] keys=consoleHandler,fileHandler @@ -11,6 +11,12 @@ keys=simpleFormatter level=DEBUG handlers=consoleHandler,fileHandler +[logger_matplotlib] +level=NOTSET +handlers= +propagate=0 +qualname=matplotlib + [handler_fileHandler] class=FileHandler level=DEBUG diff --git a/main.py b/main.py index 3d78522..0e6f1ae 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,10 @@ -import importlib - import config import logging.config # Get logging as early as possible! logging.config.fileConfig("logging.conf") +from utils import load_dotted_path from models.base_corruption import BaseCorruption from models.base_dataset import BaseDataset @@ -13,13 +12,6 @@ from models.base_encoder import BaseEncoder from models.test_run import TestRun -def load_dotted_path(path): - split_path = path.split(".") - modulename, classname = ".".join(split_path[:-1]), split_path[-1] - model = getattr(importlib.import_module(modulename), classname) - return model - - def run_tests(): logger = logging.getLogger("main.run_tests") for test in config.TEST_RUNS: @@ -41,9 +33,11 @@ def run_tests(): logger.debug(f"Using corruption model '{corruption_model.__name__}'") # Create TestRun instance - test_run = TestRun(dataset=dataset_model(**test['dataset_kwargs']), - encoder=encoder_model(**test['encoder_kwargs']), - corruption=corruption_model(**test['corruption_kwargs'])) + dataset = dataset_model(**test['dataset_kwargs']) + encoder = encoder_model(**test['encoder_kwargs']) + encoder.after_init() + corruption = corruption_model(**test['corruption_kwargs']) + test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption) # Run TestRun test_run.run(retrain=False) diff --git a/models/base_corruption.py b/models/base_corruption.py index fb754f0..eb8cd3d 100644 --- a/models/base_corruption.py +++ b/models/base_corruption.py @@ -15,7 +15,11 @@ class BaseCorruption: return f"{self.name}" @classmethod - def corrupt(cls, dataset: BaseDataset) -> BaseDataset: + def corrupt_image(cls, image): + raise NotImplementedError() + + @classmethod + def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset: raise NotImplementedError() @@ -26,5 +30,9 @@ class NoCorruption(BaseCorruption): name = "No corruption" @classmethod - def corrupt(cls, dataset: BaseDataset) -> BaseDataset: + def corrupt_image(cls, image): + return image + + @classmethod + def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset: return dataset diff --git a/models/base_dataset.py b/models/base_dataset.py index 98aeef6..abf394f 100644 --- a/models/base_dataset.py +++ b/models/base_dataset.py @@ -78,6 +78,12 @@ class BaseDataset(Dataset): self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path) self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path) + def has_train(self): + return self._trainset is not None + + def has_test(self): + return self._testset is not None + def get_train(self) -> 'BaseDataset': if not self._trainset or not self._testset: self._subdivide(self.TRAIN_AMOUNT) diff --git a/models/base_encoder.py b/models/base_encoder.py index 9728850..f167ea6 100644 --- a/models/base_encoder.py +++ b/models/base_encoder.py @@ -6,6 +6,7 @@ import torch from typing import Optional +from torch.nn.modules.loss import _Loss from torchvision.utils import save_image from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH @@ -37,7 +38,8 @@ class BaseEncoder(torch.nn.Module): torch.nn.ReLU(), torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2), torch.nn.ReLU(), - torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape) + torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape), + torch.nn.ReLU() ) # Use GPU acceleration if available @@ -50,7 +52,10 @@ class BaseEncoder(torch.nn.Module): self.loss_function = torch.nn.MSELoss() def after_init(self): - self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with {len(list(self.network.children()))} layers on {self.device.type}. Optimizer: {self.optimizer}, Loss function: {self.loss_function}") + self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with " + f"{len(list(self.network.children())) if self.network else 'custom'} layers on " + f"{self.device.type}. Optimizer: {self.optimizer.__class__.__name__}, " + f"Loss function: {self.loss_function.__class__.__name__}") def forward(self, features): return self.network(features) @@ -109,21 +114,33 @@ class BaseEncoder(torch.nn.Module): outputs = None for epoch in range(epochs): - self.log.debug(f"Start training epoch {epoch}...") + self.log.debug(f"Start training epoch {epoch + 1}...") loss = 0 - for batch_features in train_loader: - # load batch features to the active device - batch_features = batch_features.to(self.device) + for i, batch_features in enumerate(train_loader): + # # load batch features to the active device + # batch_features = batch_features.to(self.device) # reset the gradients back to zero # PyTorch accumulates gradients on subsequent backward passes self.optimizer.zero_grad() + # Modify features used in training model (if necessary) and load to the active device + train_features = self.process_train_features(batch_features).to(self.device) + # compute reconstructions - outputs = self(batch_features) + outputs = self(train_features) + + # Modify outputs used in loss function (if necessary) and load to the active device + outputs_for_loss = self.process_outputs_for_loss_function(outputs).to(self.device) + + # Modify features used in comparing in loss function (if necessary) and load to the active device + compare_features = self.process_compare_features(batch_features).to(self.device) # compute training reconstruction loss - train_loss = self.loss_function(outputs, batch_features) + train_loss = self.loss_function(outputs_for_loss, compare_features) + + # Process loss if necessary (default implementation does nothing) + train_loss = self.process_loss(train_loss, compare_features, outputs) # compute accumulated gradients train_loss.backward() @@ -134,6 +151,11 @@ class BaseEncoder(torch.nn.Module): # add the mini-batch training loss to epoch loss loss += train_loss.item() + # Print progress every 50 batches + if i % 50 == 0: + self.log.debug(f" progress: [{i * len(batch_features)}/{len(train_loader.dataset)} " + f"({(100 * i / len(train_loader)):.0f}%)]") + # compute the epoch training loss loss = loss / len(train_loader) @@ -143,7 +165,7 @@ class BaseEncoder(torch.nn.Module): # Every 5 epochs, save a test image if epoch % 5 == 0: - img = outputs.cpu().data + img = self.process_outputs_for_testing(outputs).cpu().data img = img.view(img.size(0), 3, 32, 32) save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png')) @@ -163,10 +185,25 @@ class BaseEncoder(torch.nn.Module): f'{self.name}_{dataset.name}_test_input_{i}.png')) # load batch features to the active device batch = batch.to(self.device) - outputs = self(batch) + outputs = self.process_outputs_for_testing(self(batch)) img = outputs.cpu().data img = img.view(outputs.size(0), 3, 32, 32) save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH, f'{self.name}_{dataset.name}_test_reconstruction_{i}.png')) i += 1 break + + def process_loss(self, train_loss, features, outputs) -> _Loss: + return train_loss + + def process_train_features(self, features): + return features + + def process_compare_features(self, features): + return features + + def process_outputs_for_loss_function(self, outputs): + return outputs + + def process_outputs_for_testing(self, outputs): + return outputs diff --git a/models/basic_encoder.py b/models/basic_encoder.py index d78d4c3..0746175 100644 --- a/models/basic_encoder.py +++ b/models/basic_encoder.py @@ -1,8 +1,4 @@ -import json import logging -import os - -import torch from typing import Optional @@ -19,30 +15,4 @@ class BasicAutoEncoder(BaseEncoder): # Call superclass to initialize parameters. super(BasicAutoEncoder, self).__init__(name, input_shape) - # Override parameters to custom values for this encoder type - - # TODO - Hoe kan ik het beste bepalen hoe groot de intermediate layers moeten zijn? - # - Proportioneel van input grootte naar opgegeven bottleneck grootte? - # - Uit een paper plukken - # - Zelf kiezen (e.g. helft elke keer, fixed aantal layers)? - self.network = torch.nn.Sequential( - torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2), - torch.nn.ReLU(), - torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4), - torch.nn.ReLU(), - torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2), - torch.nn.ReLU(), - torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape) - ) - - # Use GPU acceleration if available - # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Adam optimizer with learning rate 1e-3 - self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) - - # Mean Squared Error loss function - # self.loss_function = torch.nn.MSELoss() - - self.after_init() - + # Network, optimizer and loss function are the same as defined in the base encoder. diff --git a/models/cifar10_dataset.py b/models/cifar10_dataset.py index feb65b5..ead9f90 100644 --- a/models/cifar10_dataset.py +++ b/models/cifar10_dataset.py @@ -3,8 +3,6 @@ import os from typing import Optional import numpy -import torchvision -from PIL import Image from torchvision import transforms from config import DATASET_STORAGE_BASE_PATH @@ -12,6 +10,7 @@ from models.base_dataset import BaseDataset class Cifar10Dataset(BaseDataset): + name = "CIFAR-10" # transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), # torchvision.transforms.Normalize((0.5, ), (0.5, )) diff --git a/models/contractive_encoder.py b/models/contractive_encoder.py new file mode 100644 index 0000000..98685da --- /dev/null +++ b/models/contractive_encoder.py @@ -0,0 +1,88 @@ +import logging + +import torch + +from typing import Optional + +from torch.autograd import Variable + +from models.base_encoder import BaseEncoder + + +class ContractiveAutoEncoder(BaseEncoder): + # Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py + name = "ContractiveAutoEncoder" + + def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4): + self.log = logging.getLogger(self.__class__.__name__) + + # Call superclass to initialize parameters. + super(ContractiveAutoEncoder, self).__init__(name, input_shape) + + self.regularizer_weight = regularizer_weight + + # CAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder + self.network = None + self.encoder = torch.nn.Sequential( + torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2, bias=False), + torch.nn.ReLU(), + torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False), + torch.nn.ReLU() + ) + self.decoder = torch.nn.Sequential( + torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2, bias=False), + torch.nn.ReLU(), + torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape, bias=False), + torch.nn.ReLU() + ) + + # Adam optimizer with learning rate 1e-3 (parameters changed so we need to re-declare it) + self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + + # Loss function is the same as defined in the base encoder. + + # Because network is split up now, and a CAE needs the intermediate representations, + # we need to override the forward method and return the intermediate state as well. + def forward(self, features): + encoder_out = self.encoder(features) + decoder_out = self.decoder(encoder_out) + return encoder_out, decoder_out + + # Because the model returns the intermediate state as well as the output, + # we need to override some output processing functions + def process_outputs_for_loss_function(self, outputs): + return outputs[1] + + def process_outputs_for_testing(self, outputs): + return outputs[1] + + # Finally we define the loss process function, which adds the contractive loss. + def process_loss(self, train_loss, features, outputs): + """ + Evaluates the CAE loss, which is the summation of the MSE and the weighted L2-norm of the Jacobian of the + hidden units with respect to the inputs. + + Reference: http://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder + + :param train_loss: The (MSE) loss as returned by the loss function of the model + :param features: The input features + :param outputs: The raw outputs as returned by the model (in this case includes the hidden encoder output) + """ + hidden_output = outputs[0] + # Weights of the second Linear layer in the encoder (index 2) + weights = self.state_dict()['encoder.2.weight'] + + # Hadamard product + hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2]) + dh = hidden_output * (1 - hidden_output) + + # Sum through input dimension to improve efficiency (suggested in reference) + w_sum = torch.sum(Variable(weights) ** 2, dim=1) + + # Unsqueeze to avoid issues with torch.mv + w_sum = w_sum.unsqueeze(1) + + # Calculate contractive loss + contractive_loss = torch.sum(torch.mm(dh ** 2, w_sum), 0) + + return train_loss + contractive_loss.mul_(self.regularizer_weight) diff --git a/models/denoising_encoder.py b/models/denoising_encoder.py new file mode 100644 index 0000000..8af747a --- /dev/null +++ b/models/denoising_encoder.py @@ -0,0 +1,35 @@ +import logging + +import torch + +from typing import Optional + +from torch import Tensor + +from main import load_dotted_path +from models.base_corruption import BaseCorruption, NoCorruption +from models.base_encoder import BaseEncoder + + +class DenoisingAutoEncoder(BaseEncoder): + # Based on https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb + name = "DenoisingAutoEncoder" + + def __init__(self, name: Optional[str] = None, input_shape: int = 0, + input_corruption_model: BaseCorruption = NoCorruption): + self.log = logging.getLogger(self.__class__.__name__) + + # Call superclass to initialize parameters. + super(DenoisingAutoEncoder, self).__init__(name, input_shape) + + # Network, optimizer and loss function are the same as defined in the base encoder. + + # Corruption used for data corruption during training + if isinstance(input_corruption_model, str): + self.input_corruption_model = load_dotted_path(input_corruption_model) + else: + self.input_corruption_model = input_corruption_model + + # Need to corrupt features used for training, to add 'noise' to training data (comparison features are unchanged) + def process_train_features(self, features: Tensor) -> Tensor: + return torch.tensor([self.input_corruption_model.corrupt_image(x) for x in features], dtype=torch.float32) diff --git a/models/gaussian_corruption.py b/models/gaussian_corruption.py new file mode 100644 index 0000000..c9b773c --- /dev/null +++ b/models/gaussian_corruption.py @@ -0,0 +1,36 @@ +from torch import Tensor + +from models.base_corruption import BaseCorruption +from models.base_dataset import BaseDataset +import numpy + + +def add_noise(image): + image = image.astype(numpy.float32) + mean, variance = 0, 0.1 + sigma = variance ** 0.5 + noise = numpy.random.normal(mean, sigma, image.shape).reshape(image.shape) + return image + noise + + +class GaussianCorruption(BaseCorruption): + """ + Corruption model that adds Gaussian noise to the dataset. + """ + name = "Gaussian" + + @classmethod + def corrupt_image(cls, image: Tensor): + return add_noise(image.numpy()) + + @classmethod + def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset: + data = list(map(add_noise, dataset)) + train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None + test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None + return dataset.__class__.get_new( + name=f"{dataset.name} Corrupted", + data=data, + source_path=dataset._source_path, + train_set=train_set, + test_set=test_set) diff --git a/models/sparse_encoder.py b/models/sparse_encoder.py new file mode 100644 index 0000000..720a4b1 --- /dev/null +++ b/models/sparse_encoder.py @@ -0,0 +1,68 @@ +import logging +import math + +import torch + +from typing import Optional + +from torch.nn.modules.loss import _Loss + +from models.base_encoder import BaseEncoder + + +class SparseL1AutoEncoder(BaseEncoder): + # Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/ + name = "SparseL1AutoEncoder" + + def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularization_parameter: float = 0.001): + self.log = logging.getLogger(self.__class__.__name__) + + # Call superclass to initialize parameters. + super(SparseL1AutoEncoder, self).__init__(name, input_shape) + + # Override parameters to custom values for this encoder type + + # Sparse encoder has larger intermediary layers, so let's increase them 1.5 times in the first layer, + # and 2 times in the second layer (compared to the original input shape + self.network = torch.nn.Sequential( + torch.nn.Linear(in_features=input_shape, out_features=math.floor(input_shape * 1.5)), + torch.nn.ReLU(), + torch.nn.Linear(in_features=math.floor(input_shape * 1.5), out_features=input_shape * 2), + torch.nn.ReLU(), + torch.nn.Linear(in_features=input_shape * 2, out_features=math.floor(input_shape * 1.5)), + torch.nn.ReLU(), + torch.nn.Linear(in_features=math.floor(input_shape * 1.5), out_features=input_shape), + torch.nn.ReLU() + ) + + # Adam optimizer with learning rate 1e-3 (parameters changed so we need to re-declare it) + self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + + # Loss function is the same as defined in the base encoder. + + # Regularization parameter (lambda) for the L1 sparse loss function + self.regularization_parameter = regularization_parameter + + def get_sparse_loss(self, images): + def get_sparse_loss_rec(loss, values, children): + for child in children: + if isinstance(child, torch.nn.Sequential): + loss, values = get_sparse_loss_rec(loss, values, [x for x in child]) + elif isinstance(child, torch.nn.ReLU): + values = child(values) + loss += torch.mean(torch.abs(values)) + elif isinstance(child, torch.nn.Linear): + values = child(values) + else: + # Ignore unknown layers in sparse loss calculation + pass + return loss, values + + loss, values = get_sparse_loss_rec(loss=0, values=images, children=list(self.children())) + return loss + + + def process_loss(self, train_loss, features, outputs) -> _Loss: + l1_loss = self.get_sparse_loss(features) + # Add sparsity penalty + return train_loss + (self.regularization_parameter * l1_loss) diff --git a/models/test_run.py b/models/test_run.py index 684b9b5..e4b1c2d 100644 --- a/models/test_run.py +++ b/models/test_run.py @@ -4,7 +4,7 @@ import multiprocessing from models.base_corruption import BaseCorruption from models.base_dataset import BaseDataset from models.base_encoder import BaseEncoder -from output_utils import save_train_loss_graph +from utils import save_train_loss_graph class TestRun: @@ -49,7 +49,7 @@ class TestRun: # Save train loss graph self.log.info("Saving loss graph...") - save_train_loss_graph(train_loss, self.dataset.name) + save_train_loss_graph(train_loss, f"{self.encoder.name}_{self.dataset.name}") else: self.log.info("Loading saved auto-encoder...") load_success = self.encoder.load_model(f"{self.encoder.name}_{self.dataset.name}") diff --git a/models/variational_encoder.py b/models/variational_encoder.py new file mode 100644 index 0000000..4b20529 --- /dev/null +++ b/models/variational_encoder.py @@ -0,0 +1,84 @@ +import logging + +import torch + +from typing import Optional + +from models.base_encoder import BaseEncoder + + +class VariationalAutoEncoder(BaseEncoder): + # Based on https://debuggercafe.com/getting-started-with-variational-autoencoder-using-pytorch/ + # and https://github.com/pytorch/examples/blob/master/vae/main.py + name = "VariationalAutoEncoder" + + def __init__(self, name: Optional[str] = None, input_shape: int = 0): + self.log = logging.getLogger(self.__class__.__name__) + + # Call superclass to initialize parameters. + super(VariationalAutoEncoder, self).__init__(name, input_shape) + + # VAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder + # with no ReLU layer at the end of the encoder so we have access to the mu and variance. + # We also split the last layer of the encoder in two, so we can make two passes. + # One to determine the mu and one to determine the variance + self.network = None + self.encoder1 = torch.nn.Sequential( + torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2, bias=False), + torch.nn.ReLU() + ) + self.encoder2_1 = torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False) + self.encoder2_2 = torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4, bias=False) + self.decoder = torch.nn.Sequential( + torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2, bias=False), + torch.nn.ReLU(), + torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape, bias=False), + torch.nn.ReLU() + ) + + # Adam optimizer with learning rate 1e-3 (parameters changed so we need to re-declare it) + self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + + # Loss function is the same as defined in the base encoder. + + # Reparameterize takes a mu and variance, and returns a sample from the distribution N(mu(X), log_var(x)) + def reparameterize(self, mu, log_var): + # z = μ(X) + Σ1/2(X)∗e + std_dev = torch.exp(0.5 * log_var) + epsilon = torch.randn_like(std_dev) + return mu + (std_dev * epsilon) + + # Because network is split up, and a VAE needs the intermediate representations, needs to modify the input to the + # decoder, and needs to return the parameters used for the sample, we need to override the forward method. + def forward(self, features): + encoder_out = self.encoder1(features) + + # Use the two last layers of the encoder to determine the mu and log_var + mu = self.encoder2_1(encoder_out) + log_var = self.encoder2_2(encoder_out) + + # Get a sample from the distribution with mu and log_var, for use in the decoder + sample = self.reparameterize(mu, log_var) + + decoder_out = self.decoder(sample) + return decoder_out, mu, log_var + + # Because the model returns the mu and log_var in addition to the output, + # we need to override some output processing functions + def process_outputs_for_loss_function(self, outputs): + return outputs[0] + + def process_outputs_for_testing(self, outputs): + return outputs[0] + + # After the loss function is executed, we modify the loss with KL divergence + def process_loss(self, train_loss, features, outputs): + # Loss = Reconstruction loss + KL divergence loss summed over all elements and batch + _, mu, log_var = outputs + + # see Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) + return train_loss + kl_divergence diff --git a/requirements.txt b/requirements.txt index 760c8e6..20c67e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -torch==1.7.0+cu110 -torchvision==0.8.1+cu110 -torchaudio===0.7.0 +torch==1.7.1 +torchvision==0.8.2 +torchaudio===0.7.2 tabulate matplotlib \ No newline at end of file diff --git a/output_utils.py b/utils.py similarity index 94% rename from output_utils.py rename to utils.py index 582fcd9..b80b69c 100644 --- a/output_utils.py +++ b/utils.py @@ -1,3 +1,4 @@ +import importlib import os from string import Template @@ -8,6 +9,13 @@ import matplotlib.pyplot as plt from config import TRAIN_TEMP_DATA_BASE_PATH +def load_dotted_path(path): + split_path = path.split(".") + modulename, classname = ".".join(split_path[:-1]), split_path[-1] + model = getattr(importlib.import_module(modulename), classname) + return model + + def training_header(): """Prints a training header to the console""" print('| Epoch | Avg. loss | Train Acc. | Test Acc. | Elapsed | ETA |\n' @@ -106,4 +114,5 @@ def save_train_loss_graph(train_loss, filename): plt.title('Train Loss') plt.xlabel('Epochs') plt.ylabel('Loss') + plt.yscale('log') plt.savefig(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.png'))