diff --git a/logging.conf b/logging.conf
index 56846c9..9e370b7 100644
--- a/logging.conf
+++ b/logging.conf
@@ -25,7 +25,7 @@
 args=('output.log', 'w')
 
 [handler_consoleHandler]
 class=StreamHandler
-level=INFO
+level=DEBUG
 formatter=simpleFormatter
 args=(sys.stdout,)
diff --git a/main.py b/main.py
index 0e6f1ae..8e87eb0 100644
--- a/main.py
+++ b/main.py
@@ -34,6 +34,8 @@ def run_tests():
 
         # Create TestRun instance
         dataset = dataset_model(**test['dataset_kwargs'])
+        if test['encoder_kwargs'].get('input_shape', None) is None:
+            test['encoder_kwargs']['input_shape'] = dataset.get_input_shape()
         encoder = encoder_model(**test['encoder_kwargs'])
         encoder.after_init()
         corruption = corruption_model(**test['corruption_kwargs'])
diff --git a/models/base_corruption.py b/models/base_corruption.py
index eb8cd3d..b365ec0 100644
--- a/models/base_corruption.py
+++ b/models/base_corruption.py
@@ -1,3 +1,5 @@
+import torch
+
 from models.base_dataset import BaseDataset
 
 
@@ -30,8 +32,8 @@ class NoCorruption(BaseCorruption):
     name = "No corruption"
 
     @classmethod
-    def corrupt_image(cls, image):
-        return image
+    def corrupt_image(cls, image: torch.Tensor):
+        return image.numpy()
 
     @classmethod
     def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
diff --git a/models/base_dataset.py b/models/base_dataset.py
index f5513e6..091a076 100644
--- a/models/base_dataset.py
+++ b/models/base_dataset.py
@@ -62,6 +62,9 @@ class BaseDataset(Dataset):
             self._source_path = path
         raise NotImplementedError()
 
+    def get_input_shape(self):
+        return None
+
     def _subdivide(self, amount: Union[int, float]):
         if self._data is None:
             raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
diff --git a/models/base_encoder.py b/models/base_encoder.py
index 1453e0e..df386c3 100644
--- a/models/base_encoder.py
+++ b/models/base_encoder.py
@@ -117,7 +117,7 @@ class BaseEncoder(torch.nn.Module):
         outputs = None
         for epoch in range(epochs):
             self.log.debug(f"Start training epoch {epoch + 1}...")
-            loss = 0
+            loss = []
             for i, batch_features in enumerate(train_loader):
                 # # load batch features to the active device
                 # batch_features = batch_features.to(self.device)
@@ -151,15 +151,15 @@ class BaseEncoder(torch.nn.Module):
                 self.optimizer.step()
 
                 # add the mini-batch training loss to epoch loss
-                loss += train_loss.item()
+                loss.append(train_loss.item())
 
                 # Print progress every 50 batches
-                if i % 50 == 0:
+                if i % 100 == 0:
                     self.log.debug(f" progress: [{i * len(batch_features)}/{len(train_loader.dataset)} "
                                    f"({(100 * i / len(train_loader)):.0f}%)]")
 
             # compute the epoch training loss
-            loss = loss / len(train_loader)
+            loss = sum(loss) / len(loss)
 
             # display the epoch training loss
             self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
diff --git a/models/cifar10_dataset.py b/models/cifar10_dataset.py
index 120e591..ba2661c 100644
--- a/models/cifar10_dataset.py
+++ b/models/cifar10_dataset.py
@@ -50,6 +50,9 @@ class Cifar10Dataset(BaseDataset):
 
         self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")
 
+    def get_input_shape(self):
+        return 3072  # 32x32x3 (32x32px, 3 colors)
+
     def __getitem__(self, item):
         # Get image data
         img = self._data[item]
@@ -87,7 +90,7 @@ class Cifar10Dataset(BaseDataset):
         return img
 
     def save_batch_to_sample(self, batch, filename):
-        img = batch.view(batch.size(0), 3, 32, 32)
+        img = batch.view(batch.size(0), 3, 32, 32)[:48]
         save_image(img, f"{filename}.png")
 
     def calculate_score(self, originals, reconstruction, device):
diff --git a/models/mnist_dataset.py b/models/mnist_dataset.py
index cc106b6..8dd6c9b 100644
--- a/models/mnist_dataset.py
+++ b/models/mnist_dataset.py
@@ -41,6 +41,9 @@ class MNISTDataset(BaseDataset):
 
         self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")
 
+    def get_input_shape(self):
+        return 784  # 28x28x1 (28x28px, 1 color)
+
     def __getitem__(self, item):
         # Get image data
         img = self._data[item]
@@ -55,7 +58,7 @@ class MNISTDataset(BaseDataset):
         return img
 
     def save_batch_to_sample(self, batch, filename):
-        img = batch.view(batch.size(0), 1, 28, 28)
+        img = batch.view(batch.size(0), 1, 28, 28)[:48]
         save_image(img, f"{filename}.png")
 
     def calculate_score(self, originals, reconstruction, device):
diff --git a/models/random_corruption.py b/models/random_corruption.py
new file mode 100644
index 0000000..27e38a2
--- /dev/null
+++ b/models/random_corruption.py
@@ -0,0 +1,48 @@
+import random
+
+from torch import Tensor
+
+from models.base_corruption import BaseCorruption
+from models.base_dataset import BaseDataset
+import numpy
+
+
+def add_noise(image):
+    if isinstance(image, Tensor):
+        image = image.numpy()
+    image = image.astype(numpy.float32)
+
+    # 90% chance to corrupt something
+    if random.random() < 0.9:
+        corrupt_index1, corrupt_index2 = random.sample(range(len(image)), 2)
+        image[corrupt_index1] = 0
+        # 10% chance to corrupt a second column
+        if random.random() < 0.1:
+            image[corrupt_index2] = 0
+
+    return image
+
+
+class RandomCorruption(BaseCorruption):
+    """
+    Corruption model that clears random fields of data.
+    """
+    name = "Gaussian"
+
+    @classmethod
+    def corrupt_image(cls, image: Tensor):
+        return add_noise(image.numpy())
+
+    @classmethod
+    def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
+        data = [cls.corrupt_image(x) for x in dataset]
+        # data = list(map(add_noise, dataset._data))
+        train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None
+        test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None
+        return dataset.__class__.get_new(
+            name=f"{dataset.name} Corrupted",
+            data=data,
+            labels=dataset._labels,
+            source_path=dataset._source_path,
+            train_set=train_set,
+            test_set=test_set)
diff --git a/models/test_run.py b/models/test_run.py
index ff37380..f3e0cd3 100644
--- a/models/test_run.py
+++ b/models/test_run.py
@@ -4,7 +4,7 @@ import multiprocessing
 from models.base_corruption import BaseCorruption
 from models.base_dataset import BaseDataset
 from models.base_encoder import BaseEncoder
-from utils import save_train_loss_graph
+from utils import save_train_loss_graph, save_train_loss_values
 
 
 class TestRun:
@@ -41,7 +41,7 @@ class TestRun:
         if retrain:
             # Train encoder
             self.log.info("Training auto-encoder...")
-            train_loss = self.encoder.train_encoder(self.dataset, epochs=50, num_workers=multiprocessing.cpu_count() - 1)
+            train_losses = self.encoder.train_encoder(self.dataset, epochs=50, num_workers=multiprocessing.cpu_count() - 1)
 
             if save_model:
                 self.log.info("Saving auto-encoder model...")
@@ -49,7 +49,8 @@ class TestRun:
 
             # Save train loss graph
             self.log.info("Saving loss graph...")
-            save_train_loss_graph(train_loss, f"{self.encoder.name}_{self.dataset.name}")
+            save_train_loss_graph(train_losses, f"{self.encoder.name}_{self.dataset.name}")
+            save_train_loss_values(train_losses, f"{self.encoder.name}_{self.dataset.name}")
         else:
             self.log.info("Loading saved auto-encoder...")
             load_success = self.encoder.load_model(f"{self.encoder.name}_{self.dataset.name}")
diff --git a/models/usweather_dataset.py b/models/usweather_dataset.py
new file mode 100644
index 0000000..5ba1a33
--- /dev/null
+++ b/models/usweather_dataset.py
@@ -0,0 +1,192 @@
+import csv
+import os
+from collections import defaultdict
+from datetime import datetime
+
+from typing import Optional
+
+import numpy
+import torch
+
+from config import DATASET_STORAGE_BASE_PATH
+from models.base_dataset import BaseDataset
+
+
+class USWeatherEventsDataset(BaseDataset):
+    # Source: https://smoosavi.org/datasets/lstw
+    # https://www.kaggle.com/sobhanmoosavi/us-weather-events
+    name = "US Weather Events"
+
+    def transform(self, data):
+        return torch.from_numpy(numpy.array(data, numpy.float32, copy=False))
+
+    def unpickle(self, filename):
+        import pickle
+        with open(filename, 'rb') as fo:
+            dict = pickle.load(fo, encoding='bytes')
+        return dict
+
+    def load(self, name: Optional[str] = None, path: Optional[str] = None):
+        if name is not None:
+            self.name = name
+        if path is not None:
+            self._source_path = path
+
+        self._data = []
+        self._labels = defaultdict(list)
+
+        # Load from cache pickle file if it exists, else create cache file and load from csv
+        if os.path.isfile(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_data.pickle"))\
+                and os.path.isfile(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle")):
+            self.log.info("Loading cached version of dataset...")
+            self._data = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_data.pickle"))
+            self._labels = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle"))
+        else:
+            self.log.info("Creating cached version of dataset...")
+            size = 5023709
+            with open(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "WeatherEvents_Aug16_June20_Publish.csv")) as f:
+                data = csv.DictReader(f)
+                # Build label map before processing for 1-hot encoding
+                self.log.info("Preparing labels...")
+                for i, row in enumerate(data):
+                    if i % 500000 == 0:
+                        self.log.debug(f"{i} / ~{size} ({((i / size) * 100):.4f}%)")
+
+                    for label_type in ['Type', 'Severity', 'TimeZone', 'State']:
+                        if row[label_type] not in self._labels[label_type]:
+                            self._labels[label_type].append(row[label_type])
+
+            with open(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "WeatherEvents_Aug16_June20_Publish.csv")) as f:
+                data = csv.DictReader(f)
+                self.log.info("Processing data...")
+                for i, row in enumerate(data):
+                    self._data.append(numpy.array([] +
+                        # Event ID doesn't matter
+                        # 1-hot encoded event type columns
+                        [int(row['Type'] == self._labels['Type'][i]) for i in range(len(self._labels['Type']))] +
+                        # 1-hot encoded event severity columns
+                        [int(row['Severity'] == self._labels['Severity'][i]) for i in range(len(self._labels['Severity']))] +
+                        [
+                            # Start time as unix timestamp
+                            datetime.strptime(row['StartTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp(),
+                            # End time as unix timestamp
+                            datetime.strptime(row['EndTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp()
+                        ] +
+                        # 1-hot encoded event timezone columns
+                        [int(row['TimeZone'] == self._labels['TimeZone'][i]) for i in range(len(self._labels['TimeZone']))] +
+                        [
+                            # Location Latitude as float
+                            float(row['LocationLat']),
+                            # Location Longitude as float
+                            float(row['LocationLng']),
+                        ] +
+                        # 1-hot encoded event state columns
+                        [int(row['State'] == self._labels['State'][i]) for i in range(len(self._labels['State']))]
+                        # Airport code, city, county and zip code are not considered,
+                        # as they have too many unique values for 1-hot encoding.
+                    ))
+
+                    if i % 500000 == 0:
+                        self.log.debug(f"{i} / ~{size} ({((i / size) * 100):.4f}%)")
+
+            self.log.info("Shuffling data...")
+            rng = numpy.random.default_rng()
+            rng.shuffle(self._data)
+
+            self.log.info("Saving cached version...")
+            import pickle
+            with open(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_data.pickle"), 'wb') as f:
+                pickle.dump(self._data, f)
+            with open(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle"), 'wb') as f:
+                pickle.dump(dict(self._labels), f)
+            self.log.info("Cached version created.")
+
+        train_data, test_data = self._data[:2500000], self._data[2500000:]
+
+        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
+                                                source_path=self._source_path)
+
+        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, labels=self._labels,
+                                               source_path=self._source_path)
+
+        self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")
+
+    def get_input_shape(self):
+        if os.path.isfile(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle")):
+            labels = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle"))
+            size = 0
+            size += len(labels['Type'])
+            size += len(labels['Severity'])
+            size += 2
+            size += len(labels['TimeZone'])
+            size += 2
+            size += len(labels['State'])
+            return size
+        else:
+            return 69
+
+    def __getitem__(self, item):
+        data = self._data[item]
+
+        # Run transforms
+        if self.transform is not None:
+            data = self.transform(data)
+
+        return data
+
+    def save_batch_to_sample(self, batch, filename):
+        res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
+
+        for row in batch:
+            # Get 1-hot encoded values as list per value, and other values as value
+            row = row.tolist()
+            start = 0
+            length = len(self._labels['Type'])
+            event_types = row[start:start+length]
+            start += length
+            length = len(self._labels['Severity'])
+            severities = row[start:start+length]
+            start += length
+            start_time = row[start]
+            end_time = row[start+1]
+            start += 2
+            length = len(self._labels['TimeZone'])
+            timezones = row[start:start+length]
+            start += length
+            location_lat = row[start]
+            location_lng = row[start+1]
+            start += 2
+            length = len(self._labels['State'])
+            states = row[start:start+length]
+
+            # Convert 1-hot encodings to normal labels, assume highest value as the true value.
+            event_type = self._labels['Type'][event_types.index(max(event_types))]
+            severity = self._labels['Severity'][severities.index(max(severities))]
+            timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
+            state = self._labels['State'][states.index(max(states))]
+
+            # Convert timestamp float into string time
+            start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
+            end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
+
+            res.append(f"{event_type},{severity},{start_time},{end_time},{timezone},{location_lat},{location_lng},{state}\n")
+
+        with open(f"{filename}.csv", "w") as f:
+            f.writelines(res)
+
+    def calculate_score(self, originals, reconstruction, device):
+        originals = originals.to(device)
+        reconstruction = reconstruction.to(device)
+
+        total_score = 0
+        for i in range(len(originals)):
+            original, recon = originals[i], reconstruction[i]
+            total_score += sum(int(original[j] == recon[j]) for j in range(len(original))) / len(original)
+
+        return total_score / len(originals)
diff --git a/utils.py b/utils.py
index b80b69c..d2a808d 100644
--- a/utils.py
+++ b/utils.py
@@ -116,3 +116,8 @@ def save_train_loss_graph(train_loss, filename):
     plt.ylabel('Loss')
     plt.yscale('log')
     plt.savefig(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.png'))
+
+
+def save_train_loss_values(train_loss, filename):
+    with open(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.csv'), 'w') as f:
+        f.write(",".join(map(str, train_loss)))