diff --git a/models/base_dataset.py b/models/base_dataset.py index a837ade..f5513e6 100644 --- a/models/base_dataset.py +++ b/models/base_dataset.py @@ -110,3 +110,7 @@ class BaseDataset(Dataset): # Save a batch of tensors to a sample file for comparison (no implementation for base dataset) pass + def calculate_score(self, originals, reconstruction, device): + # Calculate the score given an uncorrupted and a corrupted batch. (no implementation for base dataset) + pass + diff --git a/models/base_encoder.py b/models/base_encoder.py index 885dc0f..1453e0e 100644 --- a/models/base_encoder.py +++ b/models/base_encoder.py @@ -6,10 +6,12 @@ import torch from typing import Optional +from pytorch_msssim import ssim from torch.nn.modules.loss import _Loss from torchvision.utils import save_image from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH +from models.base_corruption import BaseCorruption from models.base_dataset import BaseDataset @@ -175,31 +177,46 @@ class BaseEncoder(torch.nn.Module): return losses - def test_encoder(self, dataset: BaseDataset, batch_size: int = 128, num_workers: int = 4): + def test_encoder(self, dataset: BaseDataset, corruption: BaseCorruption, batch_size: int = 128, num_workers: int = 4): self.log.debug("Getting testing dataset DataLoader.") test_loader = dataset.get_test_loader(batch_size=batch_size, num_workers=num_workers) self.log.debug(f"Start testing...") + avg_scores = [] i = 0 for batch in test_loader: dataset.save_batch_to_sample( batch=batch, filename=os.path.join(TEST_TEMP_DATA_BASE_PATH, - f'{self.name}_{dataset.name}_test_input_{i}') + f'{self.name}_{dataset.name}_test_input_{i}_uncorrupted') + ) + corrupted_batch = torch.tensor([corruption.corrupt_image(i) for i in batch], dtype=torch.float32) + dataset.save_batch_to_sample( + batch=corrupted_batch, + filename=os.path.join(TEST_TEMP_DATA_BASE_PATH, + f'{self.name}_{dataset.name}_test_input_{i}_corrupted') ) # load batch features to the active device - batch = batch.to(self.device) - outputs = self.process_outputs_for_testing(self(batch)) + corrupted_batch = corrupted_batch.to(self.device) + outputs = self.process_outputs_for_testing(self(corrupted_batch)) img = outputs.cpu().data dataset.save_batch_to_sample( batch=img, filename=os.path.join(TEST_TEMP_DATA_BASE_PATH, f'{self.name}_{dataset.name}_test_reconstruction_{i}') ) + + batch_score = dataset.calculate_score(batch, img, self.device) + avg_scores.append(batch_score) + i += 1 break + avg_score = sum(avg_scores) / len(avg_scores) + # self.log.warning(f"Testing results - Average score: {avg_score}") + print(f"Testing results - Average score: {avg_score}") + def process_loss(self, train_loss, features, outputs) -> _Loss: return train_loss diff --git a/models/cifar10_dataset.py b/models/cifar10_dataset.py index f2257ca..120e591 100644 --- a/models/cifar10_dataset.py +++ b/models/cifar10_dataset.py @@ -3,6 +3,7 @@ import os from typing import Optional import numpy +from pytorch_msssim import ssim from torchvision import transforms from torchvision.utils import save_image @@ -70,6 +71,28 @@ class Cifar10Dataset(BaseDataset): return img + def get_as_image_array(self, item): + # Get image data + img = self._data[item] + + img = img.reshape((3, 1024)) + + # Run transforms + if self.transform is not None: + img = self.transform(img) + + # Reshape the 32x32x3 image to a 1x3072 array for the Linear layer + img = img.view(-1, 3, 32, 32) + + return img + def save_batch_to_sample(self, batch, filename): img = batch.view(batch.size(0), 3, 32, 32) save_image(img, f"{filename}.png") + + def calculate_score(self, originals, reconstruction, device): + # Calculate SSIM + originals = originals.view(originals.size(0), 3, 32, 32).to(device) + reconstruction = reconstruction.view(reconstruction.size(0), 3, 32, 32).to(device) + batch_average_score = ssim(originals, reconstruction, data_range=1, size_average=True) + return batch_average_score diff --git a/models/gaussian_corruption.py b/models/gaussian_corruption.py index 5a301cf..4f3df27 100644 --- a/models/gaussian_corruption.py +++ b/models/gaussian_corruption.py @@ -1,3 +1,4 @@ +import torch from torch import Tensor from models.base_corruption import BaseCorruption @@ -6,11 +7,13 @@ import numpy def add_noise(image): + if isinstance(image, Tensor): + image = image.numpy() image = image.astype(numpy.float32) mean, variance = 0, 0.1 sigma = variance ** 0.5 noise = numpy.random.normal(mean, sigma, image.shape).reshape(image.shape) - return image + noise + return numpy.clip(image + noise, 0, 1) class GaussianCorruption(BaseCorruption): @@ -25,7 +28,8 @@ class GaussianCorruption(BaseCorruption): @classmethod def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset: - data = list(map(add_noise, dataset)) + data = [cls.corrupt_image(x) for x in dataset] + # data = list(map(add_noise, dataset._data)) train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None return dataset.__class__.get_new( diff --git a/models/mnist_dataset.py b/models/mnist_dataset.py index 1582e08..cc106b6 100644 --- a/models/mnist_dataset.py +++ b/models/mnist_dataset.py @@ -2,6 +2,7 @@ import os from typing import Optional +from pytorch_msssim import ssim from torchvision import transforms from torchvision.datasets import MNIST from torchvision.utils import save_image @@ -56,3 +57,10 @@ class MNISTDataset(BaseDataset): def save_batch_to_sample(self, batch, filename): img = batch.view(batch.size(0), 1, 28, 28) save_image(img, f"{filename}.png") + + def calculate_score(self, originals, reconstruction, device): + # Calculate SSIM + originals = originals.view(originals.size(0), 1, 28, 28).to(device) + reconstruction = reconstruction.view(reconstruction.size(0), 1, 28, 28).to(device) + batch_average_score = ssim(originals, reconstruction, data_range=1, size_average=True) + return batch_average_score diff --git a/models/test_run.py b/models/test_run.py index e4b1c2d..ff37380 100644 --- a/models/test_run.py +++ b/models/test_run.py @@ -59,7 +59,7 @@ class TestRun: # Test encoder self.log.info("Testing auto-encoder...") - self.encoder.test_encoder(self.dataset, num_workers=multiprocessing.cpu_count() - 1) + self.encoder.test_encoder(self.dataset, corruption=self.corruption, num_workers=multiprocessing.cpu_count() - 1) self.log.info("Done!") diff --git a/requirements.txt b/requirements.txt index 20c67e6..7db4710 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ torch==1.7.1 torchvision==0.8.2 torchaudio===0.7.2 tabulate -matplotlib \ No newline at end of file +matplotlib +pytorch-msssim \ No newline at end of file