Move saving of samples to the dataset, as the process differs per dataset. Add MNIST dataset. Allow saving labels with the dataset (for use with tabular data in the future).

This commit is contained in:
Kevin Alberts 2021-01-14 18:45:26 +01:00
parent fb9ce46bd8
commit bc95548ae3
Signed by: Kurocon
GPG key ID: BCD496FEBA0C6BC1
5 changed files with 91 additions and 17 deletions

View file

@ -17,6 +17,7 @@ class BaseDataset(Dataset):
name = "BaseDataset" name = "BaseDataset"
_source_path = None _source_path = None
_data = None _data = None
_labels = None
_trainset: 'BaseDataset' = None _trainset: 'BaseDataset' = None
_testset: 'BaseDataset' = None _testset: 'BaseDataset' = None
transform = None transform = None
@ -43,11 +44,12 @@ class BaseDataset(Dataset):
return self.transform(self._data[item]) if self.transform else self._data[item] return self.transform(self._data[item]) if self.transform else self._data[item]
@classmethod @classmethod
def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None, def get_new(cls, name: str, data: Optional[list] = None, labels: Optional[dict] = None, source_path: Optional[str] = None,
train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None): train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
dset = cls() dset = cls()
dset.name = name dset.name = name
dset._data = data dset._data = data
dset._labels = labels
dset._source_path = source_path dset._source_path = source_path
dset._trainset = train_set dset._trainset = train_set
dset._testset = test_set dset._testset = test_set
@ -75,8 +77,8 @@ class BaseDataset(Dataset):
raise ValueError("Cannot subdivide! Invalid amount given, " raise ValueError("Cannot subdivide! Invalid amount given, "
"must be either a fraction between 0 and 1, or an integer.") "must be either a fraction between 0 and 1, or an integer.")
self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path) self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels, source_path=self._source_path)
self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path) self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, labels=self._labels, source_path=self._source_path)
def has_train(self): def has_train(self):
return self._trainset is not None return self._trainset is not None
@ -103,3 +105,8 @@ class BaseDataset(Dataset):
def get_test_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader: def get_test_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
return self.get_loader(self.get_test(), batch_size=batch_size, num_workers=num_workers) return self.get_loader(self.get_test(), batch_size=batch_size, num_workers=num_workers)
def save_batch_to_sample(self, batch, filename):
    """Persist a batch of sample tensors to disk for visual comparison.

    No-op in the base class: only concrete datasets know how their flat
    batch tensors should be reshaped and serialized (e.g. as image grids),
    so subclasses override this.

    :param batch: batch of tensors produced by the model / data loader.
    :param filename: destination path; extension handling is subclass-specific.
    """
    # Save a batch of tensors to a sample file for comparison (no implementation for base dataset)
    pass

View file

@ -166,9 +166,11 @@ class BaseEncoder(torch.nn.Module):
# Every 5 epochs, save a test image # Every 5 epochs, save a test image
if epoch % 5 == 0: if epoch % 5 == 0:
img = self.process_outputs_for_testing(outputs).cpu().data img = self.process_outputs_for_testing(outputs).cpu().data
img = img.view(img.size(0), 3, 32, 32) dataset.save_batch_to_sample(
save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH, batch=img,
f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png')) filename=os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png')
)
return losses return losses
@ -180,16 +182,21 @@ class BaseEncoder(torch.nn.Module):
self.log.debug(f"Start testing...") self.log.debug(f"Start testing...")
i = 0 i = 0
for batch in test_loader: for batch in test_loader:
img = batch.view(batch.size(0), 3, 32, 32) dataset.save_batch_to_sample(
save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH, batch=batch,
f'{self.name}_{dataset.name}_test_input_{i}.png')) filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
f'{self.name}_{dataset.name}_test_input_{i}')
)
# load batch features to the active device # load batch features to the active device
batch = batch.to(self.device) batch = batch.to(self.device)
outputs = self.process_outputs_for_testing(self(batch)) outputs = self.process_outputs_for_testing(self(batch))
img = outputs.cpu().data img = outputs.cpu().data
img = img.view(outputs.size(0), 3, 32, 32) dataset.save_batch_to_sample(
save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH, batch=img,
f'{self.name}_{dataset.name}_test_reconstruction_{i}.png')) filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
f'{self.name}_{dataset.name}_test_reconstruction_{i}')
)
i += 1 i += 1
break break

View file

@ -4,6 +4,7 @@ from typing import Optional
import numpy import numpy
from torchvision import transforms from torchvision import transforms
from torchvision.utils import save_image
from config import DATASET_STORAGE_BASE_PATH from config import DATASET_STORAGE_BASE_PATH
from models.base_dataset import BaseDataset from models.base_dataset import BaseDataset
@ -12,13 +13,9 @@ from models.base_dataset import BaseDataset
class Cifar10Dataset(BaseDataset): class Cifar10Dataset(BaseDataset):
name = "CIFAR-10" name = "CIFAR-10"
# transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
# torchvision.transforms.Normalize((0.5, ), (0.5, ))
# ])
transform = transforms.Compose([ transform = transforms.Compose([
transforms.ToPILImage(), transforms.ToPILImage(),
transforms.ToTensor(), transforms.ToTensor()
# transforms.Normalize((0.5,), (0.5,))
]) ])
def unpickle(self, filename): def unpickle(self, filename):
@ -72,3 +69,7 @@ class Cifar10Dataset(BaseDataset):
img = img.view(-1, 32 * 32 * 3) img = img.view(-1, 32 * 32 * 3)
return img return img
def save_batch_to_sample(self, batch, filename):
    """Save a batch of CIFAR-10 tensors as a PNG image grid.

    :param batch: tensor of flattened images, shape (N, 3*32*32).
    :param filename: destination path without extension; ".png" is appended.
    """
    # Un-flatten the (N, 3072) batch back to (N, 3, 32, 32) CHW images.
    img = batch.view(batch.size(0), 3, 32, 32)
    # Bug fix: the filename argument was previously ignored (a hard-coded
    # literal was written), so every call clobbered the same file.
    # NOTE(review): some callers already pass a ".png" suffix — confirm and
    # normalise at the call sites if a doubled extension appears.
    save_image(img, f"{filename}.png")

View file

@ -31,6 +31,7 @@ class GaussianCorruption(BaseCorruption):
return dataset.__class__.get_new( return dataset.__class__.get_new(
name=f"{dataset.name} Corrupted", name=f"{dataset.name} Corrupted",
data=data, data=data,
labels=dataset._labels,
source_path=dataset._source_path, source_path=dataset._source_path,
train_set=train_set, train_set=train_set,
test_set=test_set) test_set=test_set)

58
models/mnist_dataset.py Normal file
View file

@ -0,0 +1,58 @@
import os
from typing import Optional
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from config import DATASET_STORAGE_BASE_PATH
from models.base_dataset import BaseDataset
class MNISTDataset(BaseDataset):
    """MNIST handwritten-digit dataset (28x28 grayscale images).

    Downloads the data via torchvision and flattens each image to a
    1x784 row so it can feed Linear (fully-connected) layers.
    """
    name = "MNIST"

    # Raw MNIST samples are uint8 tensors; round-trip through PIL to get
    # float tensors scaled to [0, 1].
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor()
    ])

    def load(self, name: Optional[str] = None, path: Optional[str] = None):
        """Download (if needed) and load MNIST, populating train/test subsets.

        :param name: optional display-name override.
        :param path: optional source-path override, joined onto
                     DATASET_STORAGE_BASE_PATH.
        """
        if name is not None:
            self.name = name
        if path is not None:
            self._source_path = path

        # Both splits live under the same storage root.
        root = os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path)

        train_dataset = MNIST(root=root, train=True, download=True)
        train_data = [x for x in train_dataset.data]
        self._data = train_data
        # Hand the subset a copy so extending self._data below cannot alias it.
        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data[:],
                                                source_path=self._source_path)

        test_dataset = MNIST(root=root, train=False, download=True)
        test_data = [x for x in test_dataset.data]
        self._data.extend(test_data)
        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data[:],
                                               source_path=self._source_path)

        # NOTE(review): labels (dataset.targets) are never loaded here even
        # though BaseDataset supports them — confirm whether that is intended.
        self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")

    def __getitem__(self, item):
        """Return the transformed image at *item*, flattened to shape (1, 784)."""
        # Get image data
        img = self._data[item]
        # Run transforms
        if self.transform is not None:
            img = self.transform(img)
        # Reshape the 28x28 image to a 1x784 array for the Linear layer
        img = img.view(-1, 28 * 28)
        return img

    def save_batch_to_sample(self, batch, filename):
        """Save a batch of MNIST tensors as a PNG image grid.

        :param batch: tensor of flattened images, shape (N, 784).
        :param filename: destination path without extension; ".png" is appended.
        """
        # Un-flatten to (N, 1, 28, 28) single-channel images.
        img = batch.view(batch.size(0), 1, 28, 28)
        # Bug fix: the filename argument was previously ignored (a hard-coded
        # literal was written), so every call clobbered the same file.
        save_image(img, f"{filename}.png")