From d0785e12e26313d931863519bec3b0c766b2603c Mon Sep 17 00:00:00 2001
From: Kevin Alberts <kevin@kevinalberts.nl>
Date: Fri, 29 Jan 2021 10:52:49 +0100
Subject: [PATCH] Small changes to make the US weather dataset work properly,
 example test runs in config file.

---
 config.example.py             | 151 ++++++++++++++++++++++++++++++++--
 main.py                       |   2 +-
 models/contractive_encoder.py |   3 +-
 models/usweather_dataset.py   |  54 ++----------
 4 files changed, 154 insertions(+), 56 deletions(-)

diff --git a/config.example.py b/config.example.py
index ef885de..162a62c 100644
--- a/config.example.py
+++ b/config.example.py
@@ -3,14 +3,147 @@
 DATASET_STORAGE_BASE_PATH = "/path/to/this/project/datasets"
 TRAIN_TEMP_DATA_BASE_PATH = "/path/to/this/project/train_temp"
 TEST_TEMP_DATA_BASE_PATH = "/path/to/this/project/test_temp"
+
 TEST_RUNS = [
-    {
-        'name': "Basic test run",
-        'encoder_model': "models.base_encoder.BaseEncoder",
-        'encoder_kwargs': {},
-        'dataset_model': "models.base_dataset.BaseDataset",
-        'dataset_kwargs': {},
-        'corruption_model': "models.base_corruption.NoCorruption",
-        'corruption_kwargs': {},
-    },
+    # CIFAR-10 dataset
+    # {
+    #     'name': "CIFAR-10 on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.gaussian_corruption.GaussianCorruption"},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+
+    # MNIST dataset
+    # {
+    #     'name': "MNIST on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.gaussian_corruption.GaussianCorruption"},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+
+    # US Weather Events dataset
+    # {
+    #     'name': "US Weather Events on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.random_corruption.RandomCorruption"},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
 ]
+
diff --git a/main.py b/main.py
index 294f460..c204bb6 100644
--- a/main.py
+++ b/main.py
@@ -44,7 +44,7 @@ def run_tests():
         test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
 
         # Run TestRun
-        test_run.run(retrain=True)
+        test_run.run(retrain=False)
 
         # Cleanup to avoid out-of-memory situations when running lots of tests
         del test_run
diff --git a/models/contractive_encoder.py b/models/contractive_encoder.py
index eeadb98..b128dce 100644
--- a/models/contractive_encoder.py
+++ b/models/contractive_encoder.py
@@ -73,7 +73,8 @@ class ContractiveAutoEncoder(BaseEncoder):
         weights = self.state_dict()['encoder.2.weight']
 
         # Hadamard product
-        hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
+        if len(hidden_output.shape) > 2:
+            hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
         dh = hidden_output * (1 - hidden_output)
 
         # Sum through input dimension to improve efficiency (suggested in reference)
diff --git a/models/usweather_dataset.py b/models/usweather_dataset.py
index 6d71f06..9e1ec04 100644
--- a/models/usweather_dataset.py
+++ b/models/usweather_dataset.py
@@ -1,7 +1,6 @@
 import csv
 import os
 from collections import defaultdict
-from datetime import datetime
 
 from typing import Optional
 
@@ -26,31 +25,22 @@ class USWeatherLoss(_Loss):
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         losses = []
         start = 0
+
         length = len(self.dataset._labels['Type'])
         # Type is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
         length = len(self.dataset._labels['Severity'])
         # Severity is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
-        # Start time is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start], target[start]))
-        # End time is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
-        start += 2
         length = len(self.dataset._labels['TimeZone'])
         # TimeZone is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
-        # Location latitude is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start], target[start]))
-        # Location longitude is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
-        start += 2
         length = len(self.dataset._labels['State'])
         # State is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
 
         return sum(losses)
 
@@ -110,23 +100,9 @@ class USWeatherEventsDataset(BaseDataset):
 
                     # 1-hot encoded event severity columns
                     [int(row['Severity'] == self._labels['Severity'][i]) for i in range(len(self._labels['Severity']))] +
 
-                    [
-                        # Start time as unix timestamp
-                        datetime.strptime(row['StartTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp(),
-                        # End time as unix timestamp
-                        datetime.strptime(row['EndTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp()
-                    ] +
-
                     # 1-hot encoded event timezone columns
                     [int(row['TimeZone'] == self._labels['TimeZone'][i]) for i in range(len(self._labels['TimeZone']))] +
 
-                    [
-                        # Location Latitude as float
-                        float(row['LocationLat']),
-                        # Location Longitude as float
-                        float(row['LocationLng']),
-                    ] +
-
                     # 1-hot encoded event state columns
                     [int(row['State'] == self._labels['State'][i]) for i in range(len(self._labels['State']))]
@@ -151,7 +127,7 @@ class USWeatherEventsDataset(BaseDataset):
         # train_data, test_data = self._data[:2500000], self._data[2500000:]
 
         # Speed up training a bit
-        train_data, test_data = self._data[:50000], self._data[100000:150000]
+        train_data, test_data = self._data[:250000], self._data[250000:500000]
 
         self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data,
                                                 labels=self._labels, source_path=self._source_path)
@@ -167,13 +143,11 @@ class USWeatherEventsDataset(BaseDataset):
             size = 0
             size += len(labels['Type'])
            size += len(labels['Severity'])
-            size += 2
             size += len(labels['TimeZone'])
-            size += 2
             size += len(labels['State'])
             return size
         else:
-            return 69
+            return 65
 
     def __getitem__(self, item):
         data = self._data[item]
@@ -196,15 +170,9 @@ class USWeatherEventsDataset(BaseDataset):
         length = len(self._labels['Severity'])
         severities = output[start:start+length]
         start += length
-        start_time = output[start]
-        end_time = output[start+1]
-        start += 2
         length = len(self._labels['TimeZone'])
         timezones = output[start:start+length]
         start += length
-        location_lat = output[start]
-        location_lng = output[start+1]
-        start += 2
         length = len(self._labels['State'])
         states = output[start:start+length]
 
@@ -214,14 +182,10 @@ class USWeatherEventsDataset(BaseDataset):
 
         timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
         state = self._labels['State'][states.index(max(states))]
 
-        # Convert timestamp float into string time
-        start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
-        end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
-
-        return [event_type, severity, start_time, end_time, timezone, location_lat, location_lng, state]
+        return [event_type, severity, timezone, state]
 
     def save_batch_to_sample(self, batch, filename):
-        res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
+        res = ["Type,Severity,TimeZone,State\n"]
         for row in batch:
             row = row.tolist()
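
Note on the USWeatherLoss change above: the old code sliced the batch dimension
(input[start:start+length]), so each loss term mixed up samples; the patch slices
the feature dimension instead. A minimal standalone sketch of why that matters
(the batch size, block width, and tensors here are made up for illustration and
are not part of the patch; only torch is assumed):

    import torch
    from torch import nn

    ce_loss = nn.CrossEntropyLoss()

    batch, length = 4, 7                 # assumed batch size and width of one 1-hot block
    output = torch.rand(batch, length)   # decoder output for that block
    target = torch.eye(length)[:batch]   # ground-truth 1-hot rows

    start = 0
    # Slicing columns keeps one row per sample, giving cross entropy its
    # expected (N, C) input and (N,) class-index target:
    loss = ce_loss(output[:, start:start+length],
                   torch.argmax(target[:, start:start+length].long(), dim=1))
    print(loss.item())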