Move saving of samples to dataset, as the process differs per dataset. Add MNIST dataset. Allow saving labels with the dataset (for use in tabular data in the future)
This commit is contained in:
parent
fb9ce46bd8
commit
bc95548ae3
|
@ -17,6 +17,7 @@ class BaseDataset(Dataset):
|
||||||
name = "BaseDataset"
|
name = "BaseDataset"
|
||||||
_source_path = None
|
_source_path = None
|
||||||
_data = None
|
_data = None
|
||||||
|
_labels = None
|
||||||
_trainset: 'BaseDataset' = None
|
_trainset: 'BaseDataset' = None
|
||||||
_testset: 'BaseDataset' = None
|
_testset: 'BaseDataset' = None
|
||||||
transform = None
|
transform = None
|
||||||
|
@ -43,11 +44,12 @@ class BaseDataset(Dataset):
|
||||||
return self.transform(self._data[item]) if self.transform else self._data[item]
|
return self.transform(self._data[item]) if self.transform else self._data[item]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None,
|
def get_new(cls, name: str, data: Optional[list] = None, labels: Optional[dict] = None, source_path: Optional[str] = None,
|
||||||
train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
|
train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
|
||||||
dset = cls()
|
dset = cls()
|
||||||
dset.name = name
|
dset.name = name
|
||||||
dset._data = data
|
dset._data = data
|
||||||
|
dset._labels = labels
|
||||||
dset._source_path = source_path
|
dset._source_path = source_path
|
||||||
dset._trainset = train_set
|
dset._trainset = train_set
|
||||||
dset._testset = test_set
|
dset._testset = test_set
|
||||||
|
@ -75,8 +77,8 @@ class BaseDataset(Dataset):
|
||||||
raise ValueError("Cannot subdivide! Invalid amount given, "
|
raise ValueError("Cannot subdivide! Invalid amount given, "
|
||||||
"must be either a fraction between 0 and 1, or an integer.")
|
"must be either a fraction between 0 and 1, or an integer.")
|
||||||
|
|
||||||
self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path)
|
self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels, source_path=self._source_path)
|
||||||
self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path)
|
self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, labels=self._labels, source_path=self._source_path)
|
||||||
|
|
||||||
def has_train(self):
|
def has_train(self):
|
||||||
return self._trainset is not None
|
return self._trainset is not None
|
||||||
|
@ -103,3 +105,8 @@ class BaseDataset(Dataset):
|
||||||
|
|
||||||
def get_test_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
|
def get_test_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
|
||||||
return self.get_loader(self.get_test(), batch_size=batch_size, num_workers=num_workers)
|
return self.get_loader(self.get_test(), batch_size=batch_size, num_workers=num_workers)
|
||||||
|
|
||||||
|
def save_batch_to_sample(self, batch, filename):
|
||||||
|
# Hook for writing a batch of tensors to a sample file so inputs and
# reconstructions can be compared. The encoder's train/test code calls it
# with the keyword arguments batch= and filename=.
# The base dataset has no implementation and silently does nothing; the
# concrete datasets in this commit (Cifar10Dataset, MNISTDataset) override
# it because the reshape needed before saving differs per dataset.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
|
@ -166,9 +166,11 @@ class BaseEncoder(torch.nn.Module):
|
||||||
# Every 5 epochs, save a test image
|
# Every 5 epochs, save a test image
|
||||||
if epoch % 5 == 0:
|
if epoch % 5 == 0:
|
||||||
img = self.process_outputs_for_testing(outputs).cpu().data
|
img = self.process_outputs_for_testing(outputs).cpu().data
|
||||||
img = img.view(img.size(0), 3, 32, 32)
|
dataset.save_batch_to_sample(
|
||||||
save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
|
batch=img,
|
||||||
f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png'))
|
filename=os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
|
||||||
|
f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png')
|
||||||
|
)
|
||||||
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
@ -180,16 +182,21 @@ class BaseEncoder(torch.nn.Module):
|
||||||
self.log.debug(f"Start testing...")
|
self.log.debug(f"Start testing...")
|
||||||
i = 0
|
i = 0
|
||||||
for batch in test_loader:
|
for batch in test_loader:
|
||||||
img = batch.view(batch.size(0), 3, 32, 32)
|
dataset.save_batch_to_sample(
|
||||||
save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
|
batch=batch,
|
||||||
f'{self.name}_{dataset.name}_test_input_{i}.png'))
|
filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
|
||||||
|
f'{self.name}_{dataset.name}_test_input_{i}')
|
||||||
|
)
|
||||||
|
|
||||||
# load batch features to the active device
|
# load batch features to the active device
|
||||||
batch = batch.to(self.device)
|
batch = batch.to(self.device)
|
||||||
outputs = self.process_outputs_for_testing(self(batch))
|
outputs = self.process_outputs_for_testing(self(batch))
|
||||||
img = outputs.cpu().data
|
img = outputs.cpu().data
|
||||||
img = img.view(outputs.size(0), 3, 32, 32)
|
dataset.save_batch_to_sample(
|
||||||
save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
|
batch=img,
|
||||||
f'{self.name}_{dataset.name}_test_reconstruction_{i}.png'))
|
filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
|
||||||
|
f'{self.name}_{dataset.name}_test_reconstruction_{i}')
|
||||||
|
)
|
||||||
i += 1
|
i += 1
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
|
||||||
from config import DATASET_STORAGE_BASE_PATH
|
from config import DATASET_STORAGE_BASE_PATH
|
||||||
from models.base_dataset import BaseDataset
|
from models.base_dataset import BaseDataset
|
||||||
|
@ -12,13 +13,9 @@ from models.base_dataset import BaseDataset
|
||||||
class Cifar10Dataset(BaseDataset):
|
class Cifar10Dataset(BaseDataset):
|
||||||
name = "CIFAR-10"
|
name = "CIFAR-10"
|
||||||
|
|
||||||
# transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
|
|
||||||
# torchvision.transforms.Normalize((0.5, ), (0.5, ))
|
|
||||||
# ])
|
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.ToPILImage(),
|
transforms.ToPILImage(),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor()
|
||||||
# transforms.Normalize((0.5,), (0.5,))
|
|
||||||
])
|
])
|
||||||
|
|
||||||
def unpickle(self, filename):
|
def unpickle(self, filename):
|
||||||
|
@ -72,3 +69,7 @@ class Cifar10Dataset(BaseDataset):
|
||||||
img = img.view(-1, 32 * 32 * 3)
|
img = img.view(-1, 32 * 32 * 3)
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
def save_batch_to_sample(self, batch, filename):
    """Save a batch of flattened CIFAR-10 images as a single PNG sample.

    Args:
        batch: Tensor whose rows can be viewed as 3x32x32 images; the
            first dimension is the batch size.
        filename: Target path, with or without a ``.png`` extension.

    The extension is appended only when it is missing: the training loop
    passes a name that already ends in ``.png`` while the test loop passes
    a bare name, and blindly appending produced ``*.png.png`` files.
    """
    # Undo the flattening done in __getitem__ so the images can be saved.
    img = batch.view(batch.size(0), 3, 32, 32)

    path = str(filename)
    if not path.lower().endswith(".png"):
        path = f"{path}.png"

    save_image(img, path)
|
||||||
|
|
|
@ -31,6 +31,7 @@ class GaussianCorruption(BaseCorruption):
|
||||||
return dataset.__class__.get_new(
|
return dataset.__class__.get_new(
|
||||||
name=f"{dataset.name} Corrupted",
|
name=f"{dataset.name} Corrupted",
|
||||||
data=data,
|
data=data,
|
||||||
|
labels=dataset._labels,
|
||||||
source_path=dataset._source_path,
|
source_path=dataset._source_path,
|
||||||
train_set=train_set,
|
train_set=train_set,
|
||||||
test_set=test_set)
|
test_set=test_set)
|
||||||
|
|
58
models/mnist_dataset.py
Normal file
58
models/mnist_dataset.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.datasets import MNIST
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
|
||||||
|
from config import DATASET_STORAGE_BASE_PATH
|
||||||
|
from models.base_dataset import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
class MNISTDataset(BaseDataset):
    """MNIST dataset whose items are flattened to 1x784 tensors.

    The train and test splits are both fetched through torchvision's
    ``MNIST`` helper. Only the image data is kept; labels are not stored.
    """

    name = "MNIST"

    # Each raw item is converted to a PIL image and back to a tensor.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor()
    ])

    def load(self, name: Optional[str] = None, path: Optional[str] = None):
        """Download (if needed) and load the MNIST train and test splits.

        Args:
            name: Optional replacement for the dataset's display name.
            path: Optional sub-directory of ``DATASET_STORAGE_BASE_PATH``
                in which the MNIST files are stored.

        After the call ``_data`` holds the training items followed by the
        test items, and ``_trainset`` / ``_testset`` hold the two splits as
        separate datasets of the same class.
        """
        if name is not None:
            self.name = name
        if path is not None:
            self._source_path = path

        # _source_path defaults to None on the base class, and
        # os.path.join() raises TypeError on None. Fall back to the storage
        # base directory itself when no path was ever configured.
        root = os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path or "")

        train_data = list(MNIST(root=root, train=True, download=True).data)
        test_data = list(MNIST(root=root, train=False, download=True).data)

        # Concatenation builds a new list, so the full dataset and the two
        # splits never share a mutable list object.
        self._data = train_data + test_data

        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data,
                                                source_path=self._source_path)
        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data,
                                               source_path=self._source_path)

        self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")

    def __getitem__(self, item):
        """Return item ``item`` as a transformed tensor viewed as (-1, 784)."""
        # Get image data
        img = self._data[item]

        # Run transforms
        if self.transform is not None:
            img = self.transform(img)

        # Flatten the 28x28 image to 784 values for the Linear layer
        return img.view(-1, 28 * 28)

    def save_batch_to_sample(self, batch, filename):
        """Save a batch of flattened MNIST images as a single PNG sample.

        Args:
            batch: Tensor whose rows can be viewed as 1x28x28 images; the
                first dimension is the batch size.
            filename: Target path, with or without a ``.png`` extension.

        The extension is appended only when it is missing: the training
        loop passes a name that already ends in ``.png`` while the test
        loop passes a bare name, and blindly appending produced
        ``*.png.png`` files.
        """
        # Undo the flattening done in __getitem__ so the images can be saved.
        img = batch.view(batch.size(0), 1, 28, 28)

        path = str(filename)
        if not path.lower().endswith(".png"):
            path = f"{path}.png"

        save_image(img, path)
|
Loading…
Reference in a new issue