From d0785e12e26313d931863519bec3b0c766b2603c Mon Sep 17 00:00:00 2001
From: Kevin Alberts <kevin@kevinalberts.nl>
Date: Fri, 29 Jan 2021 10:52:49 +0100
Subject: [PATCH] Small changes to make the US weather dataset work properly;
 add example test runs to the config file.

---
 config.example.py             | 151 ++++++++++++++++++++++++++++++++--
 main.py                       |   2 +-
 models/contractive_encoder.py |   3 +-
 models/usweather_dataset.py   |  54 ++----------
 4 files changed, 154 insertions(+), 56 deletions(-)

diff --git a/config.example.py b/config.example.py
index ef885de..162a62c 100644
--- a/config.example.py
+++ b/config.example.py
@@ -3,14 +3,147 @@ DATASET_STORAGE_BASE_PATH = "/path/to/this/project/datasets"
 TRAIN_TEMP_DATA_BASE_PATH = "/path/to/this/project/train_temp"
 TEST_TEMP_DATA_BASE_PATH = "/path/to/this/project/test_temp"
 
+
 TEST_RUNS = [
-    {
-        'name': "Basic test run",
-        'encoder_model': "models.base_encoder.BaseEncoder",
-        'encoder_kwargs': {},
-        'dataset_model': "models.base_dataset.BaseDataset",
-        'dataset_kwargs': {},
-        'corruption_model': "models.base_corruption.NoCorruption",
-        'corruption_kwargs': {},
-    },
+    # CIFAR-10 dataset
+    # {
+    #     'name': "CIFAR-10 on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.gaussian_corruption.GaussianCorruption"},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+
+    # MNIST dataset
+    # {
+    #     'name': "MNIST on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.gaussian_corruption.GaussianCorruption"},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+
+    # US Weather Events dataset
+    # {
+    #     'name': "US Weather Events on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.random_corruption.RandomCorruption"},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
 ]
+
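
Note: all of the new TEST_RUNS entries ship commented out, so at least one must
be uncommented before main.py will execute anything. Each entry names its
models by dotted import path; below is a minimal sketch of how such a path can
be resolved, assuming an importlib-based loader (the project's actual loader in
main.py may differ):

    import importlib

    def load_class(dotted_path):
        # "models.cifar10_dataset.Cifar10Dataset" -> module path + class name
        module_path, _, class_name = dotted_path.rpartition(".")
        return getattr(importlib.import_module(module_path), class_name)

    # Hypothetical usage mirroring one TEST_RUNS entry:
    # DatasetCls = load_class("models.cifar10_dataset.Cifar10Dataset")
    # dataset = DatasetCls(path="cifar-10-batches-py")
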
diff --git a/main.py b/main.py
index 294f460..c204bb6 100644
--- a/main.py
+++ b/main.py
@@ -44,7 +44,7 @@ def run_tests():
         test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
 
         # Run TestRun
-        test_run.run(retrain=True)
+        test_run.run(retrain=False)
 
         # Cleanup to avoid out-of-memory situations when running lots of tests
         del test_run
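
Note: flipping retrain to False presumably makes TestRun.run reuse previously
trained encoder weights between runs rather than retraining from scratch; the
exact caching behaviour lives in TestRun.run and is not shown in this patch.
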
diff --git a/models/contractive_encoder.py b/models/contractive_encoder.py
index eeadb98..b128dce 100644
--- a/models/contractive_encoder.py
+++ b/models/contractive_encoder.py
@@ -73,7 +73,8 @@ class ContractiveAutoEncoder(BaseEncoder):
         weights = self.state_dict()['encoder.2.weight']
 
         # Hadamard product
-        hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
+        if len(hidden_output.shape) > 2:
+            hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
         dh = hidden_output * (1 - hidden_output)
 
         # Sum through input dimension to improve efficiency (suggested in reference)
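
Note: the reshape guard above feeds the contractive penalty term. For sigmoid
hidden units h = sigma(Wx + b), the Frobenius norm of the encoder Jacobian
factorizes as sum_j (h_j(1-h_j))^2 * sum_i W_ji^2, so it can be computed
without materializing the Jacobian. A minimal sketch of that computation
(names hypothetical; the project's ContractiveAutoEncoder may differ in
detail):

    import torch

    def contractive_penalty(hidden_output, weights):
        # hidden_output: (batch, n_hidden) sigmoid activations
        # weights: (n_hidden, n_input) encoder weight matrix
        dh = hidden_output * (1 - hidden_output)   # sigma'(z) for sigmoid units
        w_sq = torch.sum(weights ** 2, dim=1)      # sum_i W_ji^2, one per hidden unit
        return torch.sum(dh ** 2 * w_sq)           # ||J_f(x)||_F^2, summed over the batch
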
diff --git a/models/usweather_dataset.py b/models/usweather_dataset.py
index 6d71f06..9e1ec04 100644
--- a/models/usweather_dataset.py
+++ b/models/usweather_dataset.py
@@ -1,7 +1,6 @@
 import csv
 import os
 from collections import defaultdict
-from datetime import datetime
 
 from typing import Optional
 
@@ -26,31 +25,22 @@ class USWeatherLoss(_Loss):
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         losses = []
         start = 0
+
         length = len(self.dataset._labels['Type'])
         # Type is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
         length = len(self.dataset._labels['Severity'])
         # Severity is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
-        # Start time is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start], target[start]))
-        # End time is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
-        start += 2
         length = len(self.dataset._labels['TimeZone'])
         # TimeZone is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
-        # Location latitude is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start], target[start]))
-        # Location longitude is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
-        start += 2
         length = len(self.dataset._labels['State'])
         # State is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         return sum(losses)
 
 
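
Note: the added batch dimension in the slices is the actual fix here. The loss
receives tensors of shape (batch, features), so input[start:start+length]
sliced whole examples along the batch axis instead of the intended feature
columns. A quick illustration (shapes hypothetical):

    import torch

    x = torch.zeros(32, 65)   # (batch, features)
    a = x[0:13]               # wrong axis: 13 whole examples, shape (13, 65)
    b = x[:, 0:13]            # intended: first 13 feature columns, shape (32, 13)

The column slice also gives cross entropy what it expects: (batch, classes)
logits and a (batch,) target built by argmax over dim=1.
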
@@ -110,23 +100,9 @@ class USWeatherEventsDataset(BaseDataset):
                         # 1-hot encoded event severity columns
                         [int(row['Severity'] == self._labels['Severity'][i]) for i in range(len(self._labels['Severity']))] +
 
-                        [
-                            # Start time as unix timestamp
-                            datetime.strptime(row['StartTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp(),
-                            # End time as unix timestamp
-                            datetime.strptime(row['EndTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp()
-                        ] +
-
                         # 1-hot encoded event timezone columns
                         [int(row['TimeZone'] == self._labels['TimeZone'][i]) for i in range(len(self._labels['TimeZone']))] +
 
-                        [
-                            # Location Latitude as float
-                            float(row['LocationLat']),
-                            # Location Longitude as float
-                            float(row['LocationLng']),
-                        ] +
-
                         # 1-hot encoded event state columns
                         [int(row['State'] == self._labels['State'][i]) for i in range(len(self._labels['State']))]
 
@@ -151,7 +127,7 @@ class USWeatherEventsDataset(BaseDataset):
 
         # train_data, test_data = self._data[:2500000], self._data[2500000:]
         # Speed up training a bit
-        train_data, test_data = self._data[:50000], self._data[100000:150000]
+        train_data, test_data = self._data[:250000], self._data[250000:500000]
 
         self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
                                                 source_path=self._source_path)
@@ -167,13 +143,11 @@ class USWeatherEventsDataset(BaseDataset):
             size = 0
             size += len(labels['Type'])
             size += len(labels['Severity'])
-            size += 2
             size += len(labels['TimeZone'])
-            size += 2
             size += len(labels['State'])
             return size
         else:
-            return 69
+            return 65
 
     def __getitem__(self, item):
         data = self._data[item]
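
Note: the new fallback of 65 follows directly from the removed columns. The
previous width of 69 included the four plain numeric fields (StartTime,
EndTime, LocationLat, LocationLng); what remains is the summed width of the
four one-hot label groups (Type, Severity, TimeZone, State).
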
@@ -196,15 +170,9 @@ class USWeatherEventsDataset(BaseDataset):
         length = len(self._labels['Severity'])
         severities = output[start:start+length]
         start += length
-        start_time = output[start]
-        end_time = output[start+1]
-        start += 2
         length = len(self._labels['TimeZone'])
         timezones = output[start:start+length]
         start += length
-        location_lat = output[start]
-        location_lng = output[start+1]
-        start += 2
         length = len(self._labels['State'])
         states = output[start:start+length]
 
@@ -214,14 +182,10 @@ class USWeatherEventsDataset(BaseDataset):
         timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
         state = self._labels['State'][states.index(max(states))]
 
-        # Convert timestamp float into string time
-        start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
-        end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
-
-        return [event_type, severity, start_time, end_time, timezone, location_lat, location_lng, state]
+        return [event_type, severity, timezone, state]
 
     def save_batch_to_sample(self, batch, filename):
-        res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
+        res = ["Type,Severity,TimeZone,State\n"]
 
         for row in batch:
             row = row.tolist()