# Source: RP_AutoEncoderComparison/models/cifar10_dataset.py (Python, 76 lines, 2.6 KiB)

import os
from typing import Optional
import numpy
import torchvision
from PIL import Image
from torchvision import transforms
from config import DATASET_STORAGE_BASE_PATH
from models.base_dataset import BaseDataset
class Cifar10Dataset(BaseDataset):
    """CIFAR-10 dataset backed by the pickled ``data_batch_*`` / ``test_batch`` files.

    ``load`` reads the five training batches and the test batch from
    ``DATASET_STORAGE_BASE_PATH/<source path>`` and keeps every row in
    ``self._data`` (training rows first, then test rows). ``__getitem__``
    turns one stored row into a flattened image suitable for a Linear layer.
    """

    # Applied to every sample in __getitem__: HxWxC ndarray -> PIL image ->
    # tensor. Normalization is currently disabled
    # (previously: transforms.Normalize((0.5,), (0.5,))).
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
    ])

    def unpickle(self, filename):
        """Load and return the object pickled in *filename*.

        The file is read with ``encoding='bytes'``, so 8-bit strings pickled
        by Python 2 come back as ``bytes`` (hence keys such as ``b'data'``).
        """
        import pickle

        # NOTE(security): pickle.load can execute arbitrary code while
        # loading; only point this at trusted dataset files.
        with open(filename, 'rb') as fo:
            return pickle.load(fo, encoding='bytes')

    def load(self, name: Optional[str] = None, path: Optional[str] = None):
        """Read all batch files and build the training and testing subsets.

        Args:
            name: If given, replaces ``self.name`` before loading.
            path: If given, replaces ``self._source_path``, the directory
                (relative to ``DATASET_STORAGE_BASE_PATH``) holding the
                batch files.
        """
        if name is not None:
            self.name = name
        if path is not None:
            self._source_path = path

        # Every batch file lives in the same directory; build it once.
        batch_dir = os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path)

        self._data = []
        for i in range(1, 6):
            batch = self.unpickle(os.path.join(batch_dir, f"data_batch_{i}"))
            self._data.extend(batch[b'data'])

        # Copy taken *before* the test rows are appended below, so the
        # training subset holds training rows only.
        # get_new is not defined in this class -- presumably provided by
        # BaseDataset; confirm there.
        self._trainset = self.__class__.get_new(name=f"{self.name} Training",
                                                data=self._data[:],
                                                source_path=self._source_path)

        test_data = self.unpickle(os.path.join(batch_dir, "test_batch"))
        # The full dataset (self) contains both the training and test rows.
        self._data.extend(test_data[b'data'])
        self._testset = self.__class__.get_new(name=f"{self.name} Testing",
                                               data=test_data[b'data'][:],
                                               source_path=self._source_path)

        self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")

    def __getitem__(self, item):
        """Return sample *item* as a flattened image with 3072 values per row.

        The stored row is treated as three consecutive 1024-value planes
        (red, green, blue), each a 32x32 image. It is rearranged to a
        32x32x3 array, passed through ``self.transform`` when one is set,
        and finally reshaped to ``(-1, 3072)``.
        """
        row = self._data[item]

        # (3, 32, 32) channel planes -> (32, 32, 3) image. Made contiguous
        # so it is laid out like the array numpy.stack would have produced.
        img = numpy.ascontiguousarray(
            row.reshape((3, 32, 32)).transpose((1, 2, 0))
        )

        # Run transforms
        if self.transform is not None:
            img = self.transform(img)

        # Flatten for the Linear layer. reshape() is used instead of
        # Tensor.view() because it exists with the same meaning on both
        # torch tensors and numpy arrays; ndarray.view() expects a dtype and
        # raised TypeError whenever the transform was disabled.
        # NOTE(review): without a transform the flattened order is HxWxC,
        # whereas ToTensor (per torchvision docs) yields channel-first data.
        return img.reshape(-1, 32 * 32 * 3)