Small changes to make the US weather dataset work properly; add example test runs to the config file.
parent e4c51e2d3d
commit d0785e12e2
@@ -3,14 +3,147 @@ DATASET_STORAGE_BASE_PATH = "/path/to/this/project/datasets"
 TRAIN_TEMP_DATA_BASE_PATH = "/path/to/this/project/train_temp"
 TEST_TEMP_DATA_BASE_PATH = "/path/to/this/project/test_temp"


 TEST_RUNS = [
-    {
-        'name': "Basic test run",
-        'encoder_model': "models.base_encoder.BaseEncoder",
-        'encoder_kwargs': {},
-        'dataset_model': "models.base_dataset.BaseDataset",
-        'dataset_kwargs': {},
-        'corruption_model': "models.base_corruption.NoCorruption",
-        'corruption_kwargs': {},
-    },
+    # CIFAR-10 dataset
+    # {
+    #     'name': "CIFAR-10 on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.gaussian_corruption.GaussianCorruption"},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+
+    # MNIST dataset
+    # {
+    #     'name': "MNIST on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.gaussian_corruption.GaussianCorruption"},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+
+    # US Weather Events dataset
+    # {
+    #     'name': "US Weather Events on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.random_corruption.RandomCorruption"},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
 ]

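Every example entry lands in the config commented out, so TEST_RUNS is empty until one is enabled. For reference, enabling a run just means uncommenting the corresponding dict; an active US Weather Events entry (copied verbatim from the commented block above) would read:

TEST_RUNS = [
    {
        'name': "US Weather Events on basic auto-encoder",
        'encoder_model': "models.basic_encoder.BasicAutoEncoder",
        'encoder_kwargs': {},
        'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
        'dataset_kwargs': {"path": "weather-events"},
        'corruption_model': "models.random_corruption.RandomCorruption",
        'corruption_kwargs': {},
    },
]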
main.py (2 lines changed)
@@ -44,7 +44,7 @@ def run_tests():
         test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)

         # Run TestRun
-        test_run.run(retrain=True)
+        test_run.run(retrain=False)

         # Cleanup to avoid out-of-memory situations when running lots of tests
         del test_run
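The config entries refer to models by dotted import path. The helper main.py uses to turn such a string into a class is not part of this diff, so the following is only a minimal sketch of the assumed mechanism (the function name `resolve` and the usage below are hypothetical):

import importlib

def resolve(dotted_path: str):
    # Split "models.usweather_dataset.USWeatherEventsDataset" into module path
    # and class name, import the module, and return the class object.
    module_name, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# Hypothetical usage for one TEST_RUNS entry `run`:
# DatasetClass = resolve(run['dataset_model'])
# dataset = DatasetClass(**run['dataset_kwargs'])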
@@ -73,7 +73,8 @@ class ContractiveAutoEncoder(BaseEncoder):
         weights = self.state_dict()['encoder.2.weight']

         # Hadamard product
-        hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
+        if len(hidden_output.shape) > 2:
+            hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
         dh = hidden_output * (1 - hidden_output)

         # Sum through input dimension to improve efficiency (suggested in reference)
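The new guard presumably exists because some encoders emit a 3-D (batch, 1, hidden) activation while the weather encoder already produces (batch, hidden). For context, a minimal sketch of how the contractive penalty is typically computed from the pieces visible in this hunk, assuming a sigmoid hidden activation (which is what dh = h * (1 - h) corresponds to) and the usual Rifai-style penalty; the surrounding code is not shown here, so this is an illustration, not the project's actual implementation:

import torch

def contractive_penalty(hidden_output: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # Flatten a singleton middle dimension only when it is present.
    if len(hidden_output.shape) > 2:
        hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])

    # Element-wise derivative of the sigmoid activation (Hadamard product).
    dh = hidden_output * (1 - hidden_output)

    # Sum the squared weights over the input dimension first (the efficiency
    # trick the comment refers to), then weight by dh^2 and sum to a scalar.
    w_sum = torch.sum(weights ** 2, dim=1)   # shape: (hidden,)
    return torch.sum(dh ** 2 * w_sum)        # scalar Jacobian (Frobenius) penalty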
@@ -1,7 +1,6 @@
 import csv
 import os
 from collections import defaultdict
-from datetime import datetime

 from typing import Optional
@@ -26,31 +25,22 @@ class USWeatherLoss(_Loss):
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         losses = []
         start = 0

         length = len(self.dataset._labels['Type'])
         # Type is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
         length = len(self.dataset._labels['Severity'])
         # Severity is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
-        # Start time is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start], target[start]))
-        # End time is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
-        start += 2
         length = len(self.dataset._labels['TimeZone'])
         # TimeZone is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
-        # Location latitude is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start], target[start]))
-        # Location longitude is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
-        start += 2
         length = len(self.dataset._labels['State'])
         # State is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         return sum(losses)

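The move from input[start:start+length] to input[:, start:start+length] matters because the loss receives batched (batch, features) tensors: the old slice indexed the batch dimension instead of the one-hot feature block. A tiny illustration of the difference (the shapes here are made up for the example):

import torch

# Hypothetical batch of 4 rows with 10 features, where the first 3 features
# are one one-hot block (e.g. 'Type').
batch = torch.rand(4, 10)

wrong = batch[0:3]      # slices the *batch* dimension -> shape (3, 10)
right = batch[:, 0:3]   # slices the *feature* dimension -> shape (4, 3)

assert wrong.shape == (3, 10)
assert right.shape == (4, 3)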
@@ -110,23 +100,9 @@ class USWeatherEventsDataset(BaseDataset):
                 # 1-hot encoded event severity columns
                 [int(row['Severity'] == self._labels['Severity'][i]) for i in range(len(self._labels['Severity']))] +

-                [
-                    # Start time as unix timestamp
-                    datetime.strptime(row['StartTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp(),
-                    # End time as unix timestamp
-                    datetime.strptime(row['EndTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp()
-                ] +
-
                 # 1-hot encoded event timezone columns
                 [int(row['TimeZone'] == self._labels['TimeZone'][i]) for i in range(len(self._labels['TimeZone']))] +

-                [
-                    # Location Latitude as float
-                    float(row['LocationLat']),
-                    # Location Longitude as float
-                    float(row['LocationLng']),
-                ] +
-
                 # 1-hot encoded event state columns
                 [int(row['State'] == self._labels['State'][i]) for i in range(len(self._labels['State']))]

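With the timestamp and coordinate fields dropped, a row reduces (as far as this hunk shows) to four concatenated one-hot blocks, in the same order the loss slices them: Type, Severity, TimeZone, State. A rough sketch of that encoding, using small hypothetical label sets rather than the dataset's real ones:

labels = {
    'Type': ['Rain', 'Snow', 'Fog'],
    'Severity': ['Light', 'Moderate', 'Severe'],
    'TimeZone': ['US/Eastern', 'US/Pacific'],
    'State': ['CA', 'NY'],
}

def encode(row: dict) -> list:
    # Concatenate one one-hot block per categorical column.
    return (
        [int(row['Type'] == v) for v in labels['Type']] +
        [int(row['Severity'] == v) for v in labels['Severity']] +
        [int(row['TimeZone'] == v) for v in labels['TimeZone']] +
        [int(row['State'] == v) for v in labels['State']]
    )

print(encode({'Type': 'Snow', 'Severity': 'Severe', 'TimeZone': 'US/Pacific', 'State': 'NY'}))
# [0, 1, 0, 0, 0, 1, 0, 1, 0, 1]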
@@ -151,7 +127,7 @@ class USWeatherEventsDataset(BaseDataset):

         # train_data, test_data = self._data[:2500000], self._data[2500000:]
         # Speed up training a bit
-        train_data, test_data = self._data[:50000], self._data[100000:150000]
+        train_data, test_data = self._data[:250000], self._data[250000:500000]

         self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
                                                 source_path=self._source_path)
@@ -167,13 +143,11 @@ class USWeatherEventsDataset(BaseDataset):
             size = 0
             size += len(labels['Type'])
             size += len(labels['Severity'])
-            size += 2
             size += len(labels['TimeZone'])
-            size += 2
             size += len(labels['State'])
             return size
         else:
-            return 69
+            return 65

     def __getitem__(self, item):
         data = self._data[item]
@@ -196,15 +170,9 @@ class USWeatherEventsDataset(BaseDataset):
         length = len(self._labels['Severity'])
         severities = output[start:start+length]
         start += length
-        start_time = output[start]
-        end_time = output[start+1]
-        start += 2
         length = len(self._labels['TimeZone'])
         timezones = output[start:start+length]
         start += length
-        location_lat = output[start]
-        location_lng = output[start+1]
-        start += 2
         length = len(self._labels['State'])
         states = output[start:start+length]

@@ -214,14 +182,10 @@ class USWeatherEventsDataset(BaseDataset):
         timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
         state = self._labels['State'][states.index(max(states))]

-        # Convert timestamp float into string time
-        start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
-        end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
-
-        return [event_type, severity, start_time, end_time, timezone, location_lat, location_lng, state]
+        return [event_type, severity, timezone, state]

     def save_batch_to_sample(self, batch, filename):
-        res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
+        res = ["Type,Severity,TimeZone,State\n"]

         for row in batch:
             row = row.tolist()
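Decoding goes the other way: each one-hot block is collapsed back to a label by taking the position of its largest value, mirroring `self._labels[...][block.index(max(block))]` above. A small sketch using the same hypothetical label sets as earlier:

labels = {
    'Type': ['Rain', 'Snow', 'Fog'],
    'Severity': ['Light', 'Moderate', 'Severe'],
    'TimeZone': ['US/Eastern', 'US/Pacific'],
    'State': ['CA', 'NY'],
}

def decode(output: list) -> list:
    decoded, start = [], 0
    for key in ('Type', 'Severity', 'TimeZone', 'State'):
        block = output[start:start + len(labels[key])]
        # Pick the label whose slot has the highest score.
        decoded.append(labels[key][block.index(max(block))])
        start += len(labels[key])
    return decoded

print(decode([0.1, 0.8, 0.1, 0.2, 0.1, 0.7, 0.3, 0.6, 0.9, 0.4]))
# ['Snow', 'Severe', 'US/Pacific', 'CA']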