From bc95548ae35bd41f2160b5fbf454e5dca37cf8a3 Mon Sep 17 00:00:00 2001 From: Kevin Alberts Date: Thu, 14 Jan 2021 18:45:26 +0100 Subject: [PATCH] Move saving of samples to dataset, as the process differs per dataset. Add MNIST dataset. Allow saving labels with the dataset (for use in tabular data in the future) --- models/base_dataset.py | 13 ++++++-- models/base_encoder.py | 25 +++++++++------ models/cifar10_dataset.py | 11 ++++--- models/gaussian_corruption.py | 1 + models/mnist_dataset.py | 58 +++++++++++++++++++++++++++++++++++ 5 files changed, 91 insertions(+), 17 deletions(-) create mode 100644 models/mnist_dataset.py diff --git a/models/base_dataset.py b/models/base_dataset.py index abf394f..a837ade 100644 --- a/models/base_dataset.py +++ b/models/base_dataset.py @@ -17,6 +17,7 @@ class BaseDataset(Dataset): name = "BaseDataset" _source_path = None _data = None + _labels = None _trainset: 'BaseDataset' = None _testset: 'BaseDataset' = None transform = None @@ -43,11 +44,12 @@ class BaseDataset(Dataset): return self.transform(self._data[item]) if self.transform else self._data[item] @classmethod - def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None, + def get_new(cls, name: str, data: Optional[list] = None, labels: Optional[dict] = None, source_path: Optional[str] = None, train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None): dset = cls() dset.name = name dset._data = data + dset._labels = labels dset._source_path = source_path dset._trainset = train_set dset._testset = test_set @@ -75,8 +77,8 @@ class BaseDataset(Dataset): raise ValueError("Cannot subdivide! 
Invalid amount given, " "must be either a fraction between 0 and 1, or an integer.") - self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path) - self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path) + self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels, source_path=self._source_path) + self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, labels=self._labels, source_path=self._source_path) def has_train(self): return self._trainset is not None @@ -103,3 +105,8 @@ class BaseDataset(Dataset): def get_test_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader: return self.get_loader(self.get_test(), batch_size=batch_size, num_workers=num_workers) + + def save_batch_to_sample(self, batch, filename): + # Save a batch of tensors to a sample file for comparison (no implementation for base dataset) + pass + diff --git a/models/base_encoder.py b/models/base_encoder.py index f167ea6..885dc0f 100644 --- a/models/base_encoder.py +++ b/models/base_encoder.py @@ -166,9 +166,11 @@ class BaseEncoder(torch.nn.Module): # Every 5 epochs, save a test image if epoch % 5 == 0: img = self.process_outputs_for_testing(outputs).cpu().data - img = img.view(img.size(0), 3, 32, 32) - save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH, - f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png')) + dataset.save_batch_to_sample( + batch=img, + filename=os.path.join(TRAIN_TEMP_DATA_BASE_PATH, + f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png') + ) return losses @@ -180,16 +182,21 @@ class BaseEncoder(torch.nn.Module): self.log.debug(f"Start testing...") i = 0 for batch in test_loader: - img = batch.view(batch.size(0), 3, 32, 32) - save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH, - f'{self.name}_{dataset.name}_test_input_{i}.png')) + 
dataset.save_batch_to_sample( +                    batch=batch, +                    filename=os.path.join(TEST_TEMP_DATA_BASE_PATH, +                                          f'{self.name}_{dataset.name}_test_input_{i}') +                ) + +                # load batch features to the active device batch = batch.to(self.device) outputs = self.process_outputs_for_testing(self(batch)) img = outputs.cpu().data -                img = img.view(outputs.size(0), 3, 32, 32) -                save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH, -                                             f'{self.name}_{dataset.name}_test_reconstruction_{i}.png')) +                dataset.save_batch_to_sample( +                    batch=img, +                    filename=os.path.join(TEST_TEMP_DATA_BASE_PATH, +                                          f'{self.name}_{dataset.name}_test_reconstruction_{i}') +                ) i += 1 break diff --git a/models/cifar10_dataset.py b/models/cifar10_dataset.py index ead9f90..f2257ca 100644 --- a/models/cifar10_dataset.py +++ b/models/cifar10_dataset.py @@ -4,6 +4,7 @@ from typing import Optional import numpy from torchvision import transforms +from torchvision.utils import save_image from config import DATASET_STORAGE_BASE_PATH from models.base_dataset import BaseDataset @@ -12,13 +13,9 @@ from models.base_dataset import BaseDataset class Cifar10Dataset(BaseDataset): name = "CIFAR-10" - # transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), - # torchvision.transforms.Normalize((0.5, ), (0.5, )) - # ]) transform = transforms.Compose([ transforms.ToPILImage(), - transforms.ToTensor(), - # transforms.Normalize((0.5,), (0.5,)) + transforms.ToTensor() ]) def unpickle(self, filename): @@ -72,3 +69,7 @@ class Cifar10Dataset(BaseDataset): img = img.view(-1, 32 * 32 * 3) return img + + def save_batch_to_sample(self, batch, filename): + img = batch.view(batch.size(0), 3, 32, 32) + save_image(img, f"{filename}.png") diff --git a/models/gaussian_corruption.py b/models/gaussian_corruption.py index c9b773c..5a301cf 100644 --- a/models/gaussian_corruption.py +++ b/models/gaussian_corruption.py @@ -31,6 +31,7 @@ class GaussianCorruption(BaseCorruption): return dataset.__class__.get_new( name=f"{dataset.name} Corrupted", 
data=data, +            labels=dataset._labels, source_path=dataset._source_path, train_set=train_set, test_set=test_set) diff --git a/models/mnist_dataset.py b/models/mnist_dataset.py new file mode 100644 index 0000000..1582e08 --- /dev/null +++ b/models/mnist_dataset.py @@ -0,0 +1,58 @@ +import os + +from typing import Optional + +from torchvision import transforms +from torchvision.datasets import MNIST +from torchvision.utils import save_image + +from config import DATASET_STORAGE_BASE_PATH +from models.base_dataset import BaseDataset + + +class MNISTDataset(BaseDataset): + name = "MNIST" + + transform = transforms.Compose([ + transforms.ToPILImage(), + transforms.ToTensor() + ]) + + def load(self, name: Optional[str] = None, path: Optional[str] = None): + if name is not None: + self.name = name + if path is not None: + self._source_path = path + + train_dataset = MNIST(root=os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path), train=True, download=True) + train_data = [x for x in train_dataset.data] + self._data = train_data + + self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data[:], + source_path=self._source_path) + + test_dataset = MNIST(root=os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path), train=False, download=True) + test_data = [x for x in test_dataset.data] + self._data.extend(test_data) + + self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data[:], + source_path=self._source_path) + + self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}") + + def __getitem__(self, item): + # Get image data + img = self._data[item] + + # Run transforms + if self.transform is not None: + img = self.transform(img) + + # Reshape the 28x28x1 image to a 1x784 array for the Linear layer + img = img.view(-1, 28 * 28) + + return img + + def save_batch_to_sample(self, batch, filename): + img = batch.view(batch.size(0), 1, 28, 28) + save_image(img, f"{filename}.png")