# RP_AutoEncoderComparison/models/gaussian_corruption.py
#
# 42 lines
# 1.3 KiB
# Python

import torch
from torch import Tensor
from models.base_corruption import BaseCorruption
from models.base_dataset import BaseDataset
import numpy
def add_noise(image, *, mean: float = 0.0, variance: float = 0.1, rng=None):
    """Return a copy of *image* with additive Gaussian noise, clipped to [0, 1].

    :param image: Image as a ``torch.Tensor`` or anything ``numpy.asarray``
        accepts (e.g. a NumPy array). Pixel values are presumably already
        scaled to [0, 1] -- TODO confirm with callers; the output is clipped
        to that range.
    :param mean: Mean of the Gaussian noise (default ``0``).
    :param variance: Variance of the Gaussian noise (default ``0.1``); the
        standard deviation used for sampling is its square root.
    :param rng: Optional NumPy random source exposing
        ``normal(loc, scale, size)`` (e.g. ``numpy.random.default_rng(seed)``)
        for reproducible noise. Defaults to the global ``numpy.random`` state.
    :return: A ``float32`` NumPy array with the same shape as *image*.
    :raises ValueError: If *variance* is negative.
    """
    if variance < 0:
        raise ValueError(f"variance must be non-negative, got {variance}")
    if isinstance(image, Tensor):
        # detach()/cpu() let the conversion work for tensors that require grad
        # or live on a GPU, where a bare .numpy() raises.
        image = image.detach().cpu().numpy()
    image = numpy.asarray(image, dtype=numpy.float32)
    sigma = variance ** 0.5
    source = numpy.random if rng is None else rng
    # Sample in float32: adding float64 noise to a float32 image would silently
    # promote the result back to float64 and undo the cast above.
    noise = source.normal(mean, sigma, image.shape).astype(numpy.float32)
    return numpy.clip(image + noise, 0.0, 1.0)
class GaussianCorruption(BaseCorruption):
    """
    Corruption model that adds Gaussian noise to the dataset.
    """

    name = "Gaussian"

    @classmethod
    def corrupt_image(cls, image: Tensor):
        """Return a noisy copy of a single image.

        ``add_noise`` already converts ``Tensor`` inputs to NumPy itself, so the
        image is passed through as-is. Calling ``image.numpy()`` here first was
        redundant, and it failed for inputs that are already NumPy arrays.

        :param image: The image to corrupt (tensor or NumPy array).
        :return: A NumPy array with Gaussian noise added, clipped to [0, 1].
        """
        return add_noise(image)

    @classmethod
    def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
        """Return a new dataset whose samples have Gaussian noise added.

        Every element yielded by iterating *dataset* is passed through
        ``corrupt_image``. NOTE(review): this assumes iteration yields bare
        images rather than ``(image, label)`` pairs -- confirm against
        ``BaseDataset.__iter__``.

        Train and test splits, when present, are corrupted recursively. The
        labels object and source path are passed through unchanged (not
        copied), and the new dataset is named ``"<name> Corrupted"``.

        :param dataset: The dataset to corrupt.
        :return: A dataset of the same class, built via its ``get_new``.
        """
        data = [cls.corrupt_image(x) for x in dataset]
        train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None
        test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None
        return dataset.__class__.get_new(
            name=f"{dataset.name} Corrupted",
            data=data,
            labels=dataset._labels,
            source_path=dataset._source_path,
            train_set=train_set,
            test_set=test_set)