Add mssim score calculation, corrupt test data before testing, clip noise to avoid invalid values
This commit is contained in:
parent bc95548ae3
commit f76374111c
@@ -110,3 +110,7 @@ class BaseDataset(Dataset):
         # Save a batch of tensors to a sample file for comparison (no implementation for base dataset)
         pass
 
+    def calculate_score(self, originals, reconstruction, device):
+        # Calculate the score given an uncorrupted and a corrupted batch. (no implementation for base dataset)
+        pass
+
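The base class only declares the scoring hook; the CIFAR-10 and MNIST datasets further down in this commit implement it with SSIM. Since test_encoder averages whatever this method returns, a dataset that keeps the inherited pass would feed None into that average and fail. A possible stricter variant of the stub, for illustration only (not part of this commit):

def calculate_score(self, originals, reconstruction, device):
    # Illustrative alternative to the stub above: fail loudly instead of
    # returning None, since test_encoder sums and averages these values.
    raise NotImplementedError("calculate_score must be implemented by the dataset subclass")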
@@ -6,10 +6,12 @@ import torch
 
 from typing import Optional
 
+from pytorch_msssim import ssim
 from torch.nn.modules.loss import _Loss
 from torchvision.utils import save_image
 
 from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH
+from models.base_corruption import BaseCorruption
 from models.base_dataset import BaseDataset
 
 
@@ -175,31 +177,46 @@ class BaseEncoder(torch.nn.Module):
         return losses
 
 
-    def test_encoder(self, dataset: BaseDataset, batch_size: int = 128, num_workers: int = 4):
+    def test_encoder(self, dataset: BaseDataset, corruption: BaseCorruption, batch_size: int = 128, num_workers: int = 4):
         self.log.debug("Getting testing dataset DataLoader.")
         test_loader = dataset.get_test_loader(batch_size=batch_size, num_workers=num_workers)
 
         self.log.debug(f"Start testing...")
+        avg_scores = []
         i = 0
         for batch in test_loader:
             dataset.save_batch_to_sample(
                 batch=batch,
                 filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
-                                      f'{self.name}_{dataset.name}_test_input_{i}')
+                                      f'{self.name}_{dataset.name}_test_input_{i}_uncorrupted')
+            )
+            corrupted_batch = torch.tensor([corruption.corrupt_image(i) for i in batch], dtype=torch.float32)
+            dataset.save_batch_to_sample(
+                batch=corrupted_batch,
+                filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
+                                      f'{self.name}_{dataset.name}_test_input_{i}_corrupted')
             )
 
             # load batch features to the active device
-            batch = batch.to(self.device)
-            outputs = self.process_outputs_for_testing(self(batch))
+            corrupted_batch = corrupted_batch.to(self.device)
+            outputs = self.process_outputs_for_testing(self(corrupted_batch))
             img = outputs.cpu().data
             dataset.save_batch_to_sample(
                 batch=img,
                 filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
                                       f'{self.name}_{dataset.name}_test_reconstruction_{i}')
             )
 
+            batch_score = dataset.calculate_score(batch, img, self.device)
+            avg_scores.append(batch_score)
+
             i += 1
             break
 
+        avg_score = sum(avg_scores) / len(avg_scores)
+        # self.log.warning(f"Testing results - Average score: {avg_score}")
+        print(f"Testing results - Average score: {avg_score}")
+
     def process_loss(self, train_loss, features, outputs) -> _Loss:
         return train_loss
 
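The reworked loop corrupts each test batch, reconstructs it from the corrupted copy, and scores the reconstruction against the uncorrupted originals via dataset.calculate_score. A self-contained sketch of that flow, with Gaussian noise and an identity model standing in for the repository's corruption and encoder classes (every name below is illustrative):

import numpy
import torch
from pytorch_msssim import ssim

def add_gaussian_noise(image, variance=0.1):
    # Same idea as the GaussianCorruption below: additive noise, clipped to [0, 1].
    noise = numpy.random.normal(0, variance ** 0.5, image.shape)
    return numpy.clip(image + noise, 0, 1).astype(numpy.float32)

batch = torch.rand(4, 3 * 32 * 32)              # stand-in for one flat CIFAR-10 test batch
corrupted_batch = torch.tensor([add_gaussian_noise(img.numpy()) for img in batch],
                               dtype=torch.float32)
reconstruction = corrupted_batch                # an encoder forward pass would go here
batch_score = ssim(batch.view(-1, 3, 32, 32),
                   reconstruction.view(-1, 3, 32, 32),
                   data_range=1, size_average=True)
print(f"Testing results - Average score: {batch_score}")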
@@ -3,6 +3,7 @@ import os
 from typing import Optional
 
 import numpy
+from pytorch_msssim import ssim
 from torchvision import transforms
 from torchvision.utils import save_image
 
@@ -70,6 +71,28 @@ class Cifar10Dataset(BaseDataset):
 
         return img
 
+    def get_as_image_array(self, item):
+        # Get image data
+        img = self._data[item]
+
+        img = img.reshape((3, 1024))
+
+        # Run transforms
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Reshape to a (1, 3, 32, 32) image tensor
+        img = img.view(-1, 3, 32, 32)
+
+        return img
+
     def save_batch_to_sample(self, batch, filename):
         img = batch.view(batch.size(0), 3, 32, 32)
         save_image(img, f"{filename}.png")
+
+    def calculate_score(self, originals, reconstruction, device):
+        # Calculate SSIM
+        originals = originals.view(originals.size(0), 3, 32, 32).to(device)
+        reconstruction = reconstruction.view(reconstruction.size(0), 3, 32, 32).to(device)
+        batch_average_score = ssim(originals, reconstruction, data_range=1, size_average=True)
+        return batch_average_score
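In the ssim call above, data_range declares the dynamic range of the pixel values and size_average=True collapses the per-image scores into a single batch average. The tensors here are scaled to [0, 1], hence data_range=1. A quick standalone check of the two common scalings (illustrative only):

import torch
from pytorch_msssim import ssim

a = torch.rand(2, 3, 32, 32)                     # images already scaled to [0, 1]
b = (a + 0.05 * torch.randn_like(a)).clamp(0, 1)
print(ssim(a, b, data_range=1, size_average=True))                 # correct for [0, 1] tensors
print(ssim(a * 255, b * 255, data_range=255, size_average=True))   # matching score for [0, 255] data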
@@ -1,3 +1,4 @@
+import torch
 from torch import Tensor
 
 from models.base_corruption import BaseCorruption
@@ -6,11 +7,13 @@ import numpy
 
 
 def add_noise(image):
+    if isinstance(image, Tensor):
+        image = image.numpy()
     image = image.astype(numpy.float32)
     mean, variance = 0, 0.1
     sigma = variance ** 0.5
     noise = numpy.random.normal(mean, sigma, image.shape).reshape(image.shape)
-    return image + noise
+    return numpy.clip(image + noise, 0, 1)
 
 
 class GaussianCorruption(BaseCorruption):
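The clip is the "invalid values" part of the commit title: additive Gaussian noise can push pixels outside [0, 1], while the SSIM scoring above assumes data_range=1 and the corrupted batches are also written out as image samples. A tiny illustration of the difference:

import numpy

pixels = numpy.array([0.02, 0.5, 0.98], dtype=numpy.float32)   # toy pixel values in [0, 1]
noise = numpy.array([-0.10, 0.0, 0.10], dtype=numpy.float32)
print(pixels + noise)                    # roughly [-0.08  0.5   1.08] -> leaves the valid range
print(numpy.clip(pixels + noise, 0, 1))  # [0.    0.5   1.  ]          -> stays in [0, 1]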
@@ -25,7 +28,8 @@ class GaussianCorruption(BaseCorruption):
 
     @classmethod
     def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
-        data = list(map(add_noise, dataset))
+        data = [cls.corrupt_image(x) for x in dataset]
+        # data = list(map(add_noise, dataset._data))
         train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None
         test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None
         return dataset.__class__.get_new(
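corrupt_image itself is not shown in this diff; judging from its use here and in test_encoder, it is presumably a classmethod on GaussianCorruption that applies add_noise to a single image. A guessed sketch of that shape (an assumption about code outside this commit):

@classmethod
def corrupt_image(cls, image):
    # Assumed, not shown in this commit: clipped Gaussian noise on a single image.
    return add_noise(image)

Routing corrupt_dataset through corrupt_image keeps one per-image entry point that both whole-dataset corruption and the per-batch corruption in test_encoder can share.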
@@ -2,6 +2,7 @@ import os
 
 from typing import Optional
 
+from pytorch_msssim import ssim
 from torchvision import transforms
 from torchvision.datasets import MNIST
 from torchvision.utils import save_image
@@ -56,3 +57,10 @@ class MNISTDataset(BaseDataset):
     def save_batch_to_sample(self, batch, filename):
         img = batch.view(batch.size(0), 1, 28, 28)
         save_image(img, f"{filename}.png")
+
+    def calculate_score(self, originals, reconstruction, device):
+        # Calculate SSIM
+        originals = originals.view(originals.size(0), 1, 28, 28).to(device)
+        reconstruction = reconstruction.view(reconstruction.size(0), 1, 28, 28).to(device)
+        batch_average_score = ssim(originals, reconstruction, data_range=1, size_average=True)
+        return batch_average_score
@@ -59,7 +59,7 @@ class TestRun:
 
         # Test encoder
         self.log.info("Testing auto-encoder...")
-        self.encoder.test_encoder(self.dataset, num_workers=multiprocessing.cpu_count() - 1)
+        self.encoder.test_encoder(self.dataset, corruption=self.corruption, num_workers=multiprocessing.cpu_count() - 1)
 
         self.log.info("Done!")
 
@@ -2,4 +2,5 @@ torch==1.7.1
 torchvision==0.8.2
 torchaudio===0.7.2
 tabulate
 matplotlib
+pytorch-msssim