49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
import random
|
|
|
|
from torch import Tensor
|
|
|
|
from models.base_corruption import BaseCorruption
|
|
from models.base_dataset import BaseDataset
|
|
import numpy
|
|
|
|
|
|
def add_noise(image):
    """
    Return a float32 copy of ``image`` with up to two random entries zeroed.

    "Entries" are the items along the first axis of ``image`` (whatever
    ``image[i]`` selects). With 90% probability one randomly chosen entry is
    set to zero; when that happens there is a further 10% chance that a
    second, distinct entry is zeroed as well.

    Args:
        image: A ``torch.Tensor`` or a ``numpy.ndarray``. Tensors are
            converted to NumPy arrays first.

    Returns:
        A new ``numpy.ndarray`` of dtype ``float32``. The input is never
        modified. Inputs with a single entry can lose only that entry, and
        empty inputs are returned unchanged (apart from the dtype cast).
    """
    if isinstance(image, Tensor):
        # detach()/cpu() keep the conversion working for tensors that require
        # grad or live on another device; plain CPU tensors are unaffected.
        image = image.detach().cpu().numpy()
    # astype() copies by default, so the writes below never reach the caller.
    image = image.astype(numpy.float32)

    # 90% chance to corrupt something
    if random.random() >= 0.9:
        return image

    entry_count = len(image)
    if entry_count == 0:
        # Nothing to corrupt.
        return image
    if entry_count == 1:
        # random.sample() cannot draw two distinct indices from one entry.
        image[0] = 0
        return image

    # Draw both candidates up front so they are guaranteed to be distinct.
    first_index, second_index = random.sample(range(entry_count), 2)
    image[first_index] = 0
    # 10% chance to corrupt a second entry
    if random.random() < 0.1:
        image[second_index] = 0

    return image
|
|
|
|
|
|
class RandomCorruption(BaseCorruption):
    """
    Corruption model that clears random fields of data.
    """
    # NOTE(review): "Gaussian" does not match what this class does (it zeroes
    # entries rather than adding Gaussian noise). Left unchanged because other
    # code may look the model up by this string -- confirm before renaming.
    name = "Gaussian"

    @classmethod
    def corrupt_image(cls, image: Tensor) -> numpy.ndarray:
        """
        Return a corrupted copy of a single item.

        Delegates to ``add_noise``, which performs the Tensor-to-ndarray
        conversion itself; converting here as well would break for items
        that are already NumPy arrays.
        """
        return add_noise(image)

    @classmethod
    def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
        """
        Build a corrupted counterpart of ``dataset``.

        Every item obtained by iterating ``dataset`` is passed through
        ``corrupt_image``. Train and test subsets, when ``dataset`` reports
        having them, are corrupted recursively with this same method, so the
        recursion ends only once a subset reports no train/test subsets of
        its own.

        Returns:
            A new dataset created through ``get_new`` on the class of
            ``dataset``. Labels and source path are taken over unchanged and
            the name gets a " Corrupted" suffix.
        """
        data = [cls.corrupt_image(x) for x in dataset]
        train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None
        test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None
        return dataset.__class__.get_new(
            name=f"{dataset.name} Corrupted",
            data=data,
            labels=dataset._labels,
            source_path=dataset._source_path,
            train_set=train_set,
            test_set=test_set)
|