Move saving of samples to dataset, as the process differs per dataset. Add MNIST dataset. Allow saving labels with the dataset (for use in tabular data in the future)

Kevin Alberts 2021-01-14 18:45:26 +01:00
parent fb9ce46bd8
commit bc95548ae3
Signed by: Kurocon
GPG key ID: BCD496FEBA0C6BC1
5 changed files with 91 additions and 17 deletions
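
The commit message notes that labels are stored for future tabular datasets, which none of the diffs below exercise yet. A minimal sketch of how a subclass could use the new _labels field in its save_batch_to_sample override, assuming a hypothetical TabularDataset in which _labels maps column indices to column names (the class, the CSV layout, and that mapping convention are all assumptions, not part of this commit):

    import csv

    from models.base_dataset import BaseDataset


    class TabularDataset(BaseDataset):
        # Hypothetical example only: shows how labels stored on the dataset
        # could drive sample serialization for tabular data.
        name = "Tabular (example)"

        def save_batch_to_sample(self, batch, filename):
            # Use the labels dict ({column_index: column_name}) as a CSV header,
            # then write one row per sample tensor in the batch.
            header = [self._labels[i] for i in sorted(self._labels)] if self._labels else []
            with open(f"{filename}.csv", "w", newline="") as f:
                writer = csv.writer(f)
                if header:
                    writer.writerow(header)
                writer.writerows(batch.tolist())

This is the same pattern the commit applies to images below: each dataset knows its own sample shape and file format, so the encoder no longer hard-codes a 3x32x32 reshape.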


@@ -17,6 +17,7 @@ class BaseDataset(Dataset):
     name = "BaseDataset"
     _source_path = None
     _data = None
+    _labels = None
     _trainset: 'BaseDataset' = None
     _testset: 'BaseDataset' = None
     transform = None
@@ -43,11 +44,12 @@ class BaseDataset(Dataset):
         return self.transform(self._data[item]) if self.transform else self._data[item]
     @classmethod
-    def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None,
+    def get_new(cls, name: str, data: Optional[list] = None, labels: Optional[dict] = None, source_path: Optional[str] = None,
                 train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
         dset = cls()
         dset.name = name
         dset._data = data
+        dset._labels = labels
         dset._source_path = source_path
         dset._trainset = train_set
         dset._testset = test_set
@@ -75,8 +77,8 @@ class BaseDataset(Dataset):
             raise ValueError("Cannot subdivide! Invalid amount given, "
                              "must be either a fraction between 0 and 1, or an integer.")
-        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path)
-        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path)
+        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels, source_path=self._source_path)
+        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, labels=self._labels, source_path=self._source_path)
     def has_train(self):
         return self._trainset is not None
@@ -103,3 +105,8 @@ class BaseDataset(Dataset):
     def get_test_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
         return self.get_loader(self.get_test(), batch_size=batch_size, num_workers=num_workers)
+    def save_batch_to_sample(self, batch, filename):
+        # Save a batch of tensors to a sample file for comparison (no implementation for base dataset)
+        pass


@@ -166,9 +166,11 @@ class BaseEncoder(torch.nn.Module):
             # Every 5 epochs, save a test image
             if epoch % 5 == 0:
                 img = self.process_outputs_for_testing(outputs).cpu().data
-                img = img.view(img.size(0), 3, 32, 32)
-                save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
-                                             f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png'))
+                dataset.save_batch_to_sample(
+                    batch=img,
+                    filename=os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
+                                          f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png')
+                )
         return losses
@@ -180,16 +182,21 @@ class BaseEncoder(torch.nn.Module):
         self.log.debug(f"Start testing...")
         i = 0
         for batch in test_loader:
-            img = batch.view(batch.size(0), 3, 32, 32)
-            save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
-                                         f'{self.name}_{dataset.name}_test_input_{i}.png'))
+            dataset.save_batch_to_sample(
+                batch=batch,
+                filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
+                                      f'{self.name}_{dataset.name}_test_input_{i}')
+            )
             # load batch features to the active device
             batch = batch.to(self.device)
             outputs = self.process_outputs_for_testing(self(batch))
             img = outputs.cpu().data
-            img = img.view(outputs.size(0), 3, 32, 32)
-            save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
-                                         f'{self.name}_{dataset.name}_test_reconstruction_{i}.png'))
+            dataset.save_batch_to_sample(
+                batch=img,
+                filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
+                                      f'{self.name}_{dataset.name}_test_reconstruction_{i}')
+            )
             i += 1
             break


@@ -4,6 +4,7 @@ from typing import Optional
 import numpy
 from torchvision import transforms
+from torchvision.utils import save_image
 from config import DATASET_STORAGE_BASE_PATH
 from models.base_dataset import BaseDataset
@@ -12,13 +13,9 @@ from models.base_dataset import BaseDataset
 class Cifar10Dataset(BaseDataset):
     name = "CIFAR-10"
-    # transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
-    #                                              torchvision.transforms.Normalize((0.5, ), (0.5, ))
-    #                                              ])
     transform = transforms.Compose([
         transforms.ToPILImage(),
-        transforms.ToTensor(),
-        # transforms.Normalize((0.5,), (0.5,))
+        transforms.ToTensor()
     ])
     def unpickle(self, filename):
@@ -72,3 +69,7 @@ class Cifar10Dataset(BaseDataset):
         img = img.view(-1, 32 * 32 * 3)
         return img
+    def save_batch_to_sample(self, batch, filename):
+        img = batch.view(batch.size(0), 3, 32, 32)
+        save_image(img, f"{filename}.png")


@@ -31,6 +31,7 @@ class GaussianCorruption(BaseCorruption):
         return dataset.__class__.get_new(
             name=f"{dataset.name} Corrupted",
             data=data,
+            labels=dataset._labels,
             source_path=dataset._source_path,
             train_set=train_set,
             test_set=test_set)

models/mnist_dataset.py (new file, 58 lines)

@@ -0,0 +1,58 @@
+import os
+from typing import Optional
+from torchvision import transforms
+from torchvision.datasets import MNIST
+from torchvision.utils import save_image
+from config import DATASET_STORAGE_BASE_PATH
+from models.base_dataset import BaseDataset
+class MNISTDataset(BaseDataset):
+    name = "MNIST"
+    transform = transforms.Compose([
+        transforms.ToPILImage(),
+        transforms.ToTensor()
+    ])
+    def load(self, name: Optional[str] = None, path: Optional[str] = None):
+        if name is not None:
+            self.name = name
+        if path is not None:
+            self._source_path = path
+        train_dataset = MNIST(root=os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path), train=True, download=True)
+        train_data = [x for x in train_dataset.data]
+        self._data = train_data
+        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data[:],
+                                                source_path=self._source_path)
+        test_dataset = MNIST(root=os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path), train=False, download=True)
+        test_data = [x for x in test_dataset.data]
+        self._data.extend(test_data)
+        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data[:],
+                                               source_path=self._source_path)
+        self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")
+    def __getitem__(self, item):
+        # Get image data
+        img = self._data[item]
+        # Run transforms
+        if self.transform is not None:
+            img = self.transform(img)
+        # Reshape the 28x28x1 image to a 1x784 array for the Linear layer
+        img = img.view(-1, 28 * 28)
+        return img
+    def save_batch_to_sample(self, batch, filename):
+        img = batch.view(batch.size(0), 1, 28, 28)
+        save_image(img, f"{filename}.png")
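
For reference, a short usage sketch of the new dataset. The "mnist" storage subdirectory and the batch size are assumptions, and it relies on DATASET_STORAGE_BASE_PATH pointing at a writable location:

    from models.mnist_dataset import MNISTDataset

    dataset = MNISTDataset()
    dataset.load(path="mnist")                # torchvision downloads MNIST here if it is missing
    loader = dataset.get_test_loader(batch_size=16)

    batch = next(iter(loader))                # shape (16, 1, 784), see __getitem__ above
    dataset.save_batch_to_sample(batch, "mnist_sample")  # writes mnist_sample.png with 28x28 digits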