Add SSIM score calculation, corrupt test data before testing, clip noise to avoid invalid values
This commit is contained in:
		
							parent
							
								
									bc95548ae3
								
							
						
					
					
						commit
						f76374111c
					
				
					 7 changed files with 65 additions and 8 deletions
				
			
		| 
						 | 
					@ -110,3 +110,7 @@ class BaseDataset(Dataset):
 | 
				
			||||||
        # Save a batch of tensors to a sample file for comparison (no implementation for base dataset)
 | 
					        # Save a batch of tensors to a sample file for comparison (no implementation for base dataset)
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def calculate_score(self, originals, reconstruction, device):
 | 
				
			||||||
 | 
					        # Calculate the score for a batch of originals and the corresponding reconstructions. (no implementation for base dataset)
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -6,10 +6,12 @@ import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pytorch_msssim import ssim
 | 
				
			||||||
from torch.nn.modules.loss import _Loss
 | 
					from torch.nn.modules.loss import _Loss
 | 
				
			||||||
from torchvision.utils import save_image
 | 
					from torchvision.utils import save_image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH
 | 
					from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH
 | 
				
			||||||
 | 
					from models.base_corruption import BaseCorruption
 | 
				
			||||||
from models.base_dataset import BaseDataset
 | 
					from models.base_dataset import BaseDataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -175,31 +177,46 @@ class BaseEncoder(torch.nn.Module):
 | 
				
			||||||
        return losses
 | 
					        return losses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_encoder(self, dataset: BaseDataset, batch_size: int = 128, num_workers: int = 4):
 | 
					    def test_encoder(self, dataset: BaseDataset, corruption: BaseCorruption, batch_size: int = 128, num_workers: int = 4):
 | 
				
			||||||
        self.log.debug("Getting testing dataset DataLoader.")
 | 
					        self.log.debug("Getting testing dataset DataLoader.")
 | 
				
			||||||
        test_loader = dataset.get_test_loader(batch_size=batch_size, num_workers=num_workers)
 | 
					        test_loader = dataset.get_test_loader(batch_size=batch_size, num_workers=num_workers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.log.debug(f"Start testing...")
 | 
					        self.log.debug(f"Start testing...")
 | 
				
			||||||
 | 
					        avg_scores = []
 | 
				
			||||||
        i = 0
 | 
					        i = 0
 | 
				
			||||||
        for batch in test_loader:
 | 
					        for batch in test_loader:
 | 
				
			||||||
            dataset.save_batch_to_sample(
 | 
					            dataset.save_batch_to_sample(
 | 
				
			||||||
                batch=batch,
 | 
					                batch=batch,
 | 
				
			||||||
                filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
 | 
					                filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
 | 
				
			||||||
                                      f'{self.name}_{dataset.name}_test_input_{i}')
 | 
					                                      f'{self.name}_{dataset.name}_test_input_{i}_uncorrupted')
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            corrupted_batch = torch.tensor([corruption.corrupt_image(i) for i in batch], dtype=torch.float32)
 | 
				
			||||||
 | 
					            dataset.save_batch_to_sample(
 | 
				
			||||||
 | 
					                batch=corrupted_batch,
 | 
				
			||||||
 | 
					                filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
 | 
				
			||||||
 | 
					                                      f'{self.name}_{dataset.name}_test_input_{i}_corrupted')
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # load batch features to the active device
 | 
					            # load batch features to the active device
 | 
				
			||||||
            batch = batch.to(self.device)
 | 
					            corrupted_batch = corrupted_batch.to(self.device)
 | 
				
			||||||
            outputs = self.process_outputs_for_testing(self(batch))
 | 
					            outputs = self.process_outputs_for_testing(self(corrupted_batch))
 | 
				
			||||||
            img = outputs.cpu().data
 | 
					            img = outputs.cpu().data
 | 
				
			||||||
            dataset.save_batch_to_sample(
 | 
					            dataset.save_batch_to_sample(
 | 
				
			||||||
                batch=img,
 | 
					                batch=img,
 | 
				
			||||||
                filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
 | 
					                filename=os.path.join(TEST_TEMP_DATA_BASE_PATH,
 | 
				
			||||||
                                      f'{self.name}_{dataset.name}_test_reconstruction_{i}')
 | 
					                                      f'{self.name}_{dataset.name}_test_reconstruction_{i}')
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            batch_score = dataset.calculate_score(batch, img, self.device)
 | 
				
			||||||
 | 
					            avg_scores.append(batch_score)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            i += 1
 | 
					            i += 1
 | 
				
			||||||
            break
 | 
					            break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        avg_score = sum(avg_scores) / len(avg_scores)
 | 
				
			||||||
 | 
					        # self.log.warning(f"Testing results - Average score: {avg_score}")
 | 
				
			||||||
 | 
					        print(f"Testing results - Average score: {avg_score}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def process_loss(self, train_loss, features, outputs) -> _Loss:
 | 
					    def process_loss(self, train_loss, features, outputs) -> _Loss:
 | 
				
			||||||
        return train_loss
 | 
					        return train_loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,6 +3,7 @@ import os
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy
 | 
					import numpy
 | 
				
			||||||
 | 
					from pytorch_msssim import ssim
 | 
				
			||||||
from torchvision import transforms
 | 
					from torchvision import transforms
 | 
				
			||||||
from torchvision.utils import save_image
 | 
					from torchvision.utils import save_image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -70,6 +71,28 @@ class Cifar10Dataset(BaseDataset):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return img
 | 
					        return img
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_as_image_array(self, item):
 | 
				
			||||||
 | 
					        # Get image data
 | 
				
			||||||
 | 
					        img = self._data[item]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        img = img.reshape((3, 1024))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Run transforms
 | 
				
			||||||
 | 
					        if self.transform is not None:
 | 
				
			||||||
 | 
					            img = self.transform(img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Reshape the image data to (N, 3, 32, 32): N is inferred, 3 channels of 32x32 pixels
 | 
				
			||||||
 | 
					        img = img.view(-1, 3, 32, 32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return img
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def save_batch_to_sample(self, batch, filename):
 | 
					    def save_batch_to_sample(self, batch, filename):
 | 
				
			||||||
        img = batch.view(batch.size(0), 3, 32, 32)
 | 
					        img = batch.view(batch.size(0), 3, 32, 32)
 | 
				
			||||||
        save_image(img, f"{filename}.png")
 | 
					        save_image(img, f"{filename}.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def calculate_score(self, originals, reconstruction, device):
 | 
				
			||||||
 | 
					        # Calculate SSIM
 | 
				
			||||||
 | 
					        originals = originals.view(originals.size(0), 3, 32, 32).to(device)
 | 
				
			||||||
 | 
					        reconstruction = reconstruction.view(reconstruction.size(0), 3, 32, 32).to(device)
 | 
				
			||||||
 | 
					        batch_average_score = ssim(originals, reconstruction, data_range=1, size_average=True)
 | 
				
			||||||
 | 
					        return batch_average_score
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,3 +1,4 @@
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
from torch import Tensor
 | 
					from torch import Tensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from models.base_corruption import BaseCorruption
 | 
					from models.base_corruption import BaseCorruption
 | 
				
			||||||
| 
						 | 
					@ -6,11 +7,13 @@ import numpy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def add_noise(image):
 | 
					def add_noise(image):
 | 
				
			||||||
 | 
					    if isinstance(image, Tensor):
 | 
				
			||||||
 | 
					        image = image.numpy()
 | 
				
			||||||
    image = image.astype(numpy.float32)
 | 
					    image = image.astype(numpy.float32)
 | 
				
			||||||
    mean, variance = 0, 0.1
 | 
					    mean, variance = 0, 0.1
 | 
				
			||||||
    sigma = variance ** 0.5
 | 
					    sigma = variance ** 0.5
 | 
				
			||||||
    noise = numpy.random.normal(mean, sigma, image.shape).reshape(image.shape)
 | 
					    noise = numpy.random.normal(mean, sigma, image.shape).reshape(image.shape)
 | 
				
			||||||
    return image + noise
 | 
					    return numpy.clip(image + noise, 0, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GaussianCorruption(BaseCorruption):
 | 
					class GaussianCorruption(BaseCorruption):
 | 
				
			||||||
| 
						 | 
					@ -25,7 +28,8 @@ class GaussianCorruption(BaseCorruption):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
 | 
					    def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
 | 
				
			||||||
        data = list(map(add_noise, dataset))
 | 
					        data = [cls.corrupt_image(x) for x in dataset]
 | 
				
			||||||
 | 
					        # data = list(map(add_noise, dataset._data))
 | 
				
			||||||
        train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None
 | 
					        train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None
 | 
				
			||||||
        test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None
 | 
					        test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None
 | 
				
			||||||
        return dataset.__class__.get_new(
 | 
					        return dataset.__class__.get_new(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,6 +2,7 @@ import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pytorch_msssim import ssim
 | 
				
			||||||
from torchvision import transforms
 | 
					from torchvision import transforms
 | 
				
			||||||
from torchvision.datasets import MNIST
 | 
					from torchvision.datasets import MNIST
 | 
				
			||||||
from torchvision.utils import save_image
 | 
					from torchvision.utils import save_image
 | 
				
			||||||
| 
						 | 
					@ -56,3 +57,10 @@ class MNISTDataset(BaseDataset):
 | 
				
			||||||
    def save_batch_to_sample(self, batch, filename):
 | 
					    def save_batch_to_sample(self, batch, filename):
 | 
				
			||||||
        img = batch.view(batch.size(0), 1, 28, 28)
 | 
					        img = batch.view(batch.size(0), 1, 28, 28)
 | 
				
			||||||
        save_image(img, f"{filename}.png")
 | 
					        save_image(img, f"{filename}.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def calculate_score(self, originals, reconstruction, device):
 | 
				
			||||||
 | 
					        # Calculate SSIM
 | 
				
			||||||
 | 
					        originals = originals.view(originals.size(0), 1, 28, 28).to(device)
 | 
				
			||||||
 | 
					        reconstruction = reconstruction.view(reconstruction.size(0), 1, 28, 28).to(device)
 | 
				
			||||||
 | 
					        batch_average_score = ssim(originals, reconstruction, data_range=1, size_average=True)
 | 
				
			||||||
 | 
					        return batch_average_score
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -59,7 +59,7 @@ class TestRun:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Test encoder
 | 
					        # Test encoder
 | 
				
			||||||
        self.log.info("Testing auto-encoder...")
 | 
					        self.log.info("Testing auto-encoder...")
 | 
				
			||||||
        self.encoder.test_encoder(self.dataset, num_workers=multiprocessing.cpu_count() - 1)
 | 
					        self.encoder.test_encoder(self.dataset, corruption=self.corruption, num_workers=multiprocessing.cpu_count() - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.log.info("Done!")
 | 
					        self.log.info("Done!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,3 +3,4 @@ torchvision==0.8.2
 | 
				
			||||||
torchaudio===0.7.2
 | 
					torchaudio===0.7.2
 | 
				
			||||||
tabulate
 | 
					tabulate
 | 
				
			||||||
matplotlib
 | 
					matplotlib
 | 
				
			||||||
 | 
					pytorch-msssim
 | 
				
			||||||
		Reference in a new issue