Allow loss function to be defined by the dataset. Add specialized loss function for US weather dataset

Kevin Alberts 2021-01-21 21:47:00 +01:00
parent f6a19c4921
commit e4c51e2d3d
Signed by: Kurocon
GPG key ID: BCD496FEBA0C6BC1
9 changed files with 116 additions and 46 deletions

main.py

@@ -36,13 +36,21 @@ def run_tests():
dataset = dataset_model(**test['dataset_kwargs'])
if test['encoder_kwargs'].get('input_shape', None) is None:
test['encoder_kwargs']['input_shape'] = dataset.get_input_shape()
if test['encoder_kwargs'].get('loss_function', None) is None:
test['encoder_kwargs']['loss_function'] = dataset.get_loss_function()
encoder = encoder_model(**test['encoder_kwargs'])
encoder.after_init()
corruption = corruption_model(**test['corruption_kwargs'])
test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
# Run TestRun
test_run.run(retrain=False)
test_run.run(retrain=True)
# Cleanup to avoid out-of-memory situations when running lots of tests
del test_run
del corruption
del encoder
del dataset
if __name__ == '__main__':
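With these defaults in place, a test definition only has to specify a loss function when it wants to override the dataset's choice. A minimal sketch of such an entry (the keys mirror run_tests() above; the concrete values are illustrative, not taken from the repository):

# Hypothetical test entry: 'input_shape' and 'loss_function' are omitted
# on purpose, so run_tests() fills them in from the dataset.
test = {
    'dataset_kwargs': {},
    'encoder_kwargs': {'name': 'demo-encoder'},
    'corruption_kwargs': {},
}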


@@ -65,6 +65,9 @@ class BaseDataset(Dataset):
def get_input_shape(self):
return None
def get_loss_function(self):
return torch.nn.MSELoss()
def _subdivide(self, amount: Union[int, float]):
if self._data is None:
raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
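Subclasses can override this hook to pair a dataset with a loss that fits its columns; a minimal sketch, assuming nothing beyond the hook shown above (the subclass name is illustrative):

import torch

from models.base_dataset import BaseDataset


class AbsoluteErrorDataset(BaseDataset):
    def get_loss_function(self):
        # Prefer mean absolute error over the MSE default.
        return torch.nn.L1Loss()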


@@ -19,7 +19,7 @@ class BaseEncoder(torch.nn.Module):
# Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
name = "BaseEncoder"
def __init__(self, name: Optional[str] = None, input_shape: int = 0):
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
super(BaseEncoder, self).__init__()
self.log = logging.getLogger(self.__class__.__name__)
@@ -51,7 +51,10 @@ class BaseEncoder(torch.nn.Module):
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
# Loss function: use the one supplied by the caller (e.g. the dataset), or fall back to Mean Squared Error
self.loss_function = torch.nn.MSELoss()
if loss_function is not None:
self.loss_function = loss_function
else:
self.loss_function = torch.nn.MSELoss()
def after_init(self):
self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with "
@@ -163,6 +166,9 @@ class BaseEncoder(torch.nn.Module):
# display the epoch training loss
self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
self.log.debug(f"Expected: {compare_features.cpu().detach().numpy()[0].tolist()}")
self.log.debug(f"Outputs: {outputs_for_loss.cpu().detach().numpy()[0].tolist()}")
self.log.debug(f"Loss: {train_loss}")
losses.append(loss)
# Every 5 epochs, save a test image
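Because every encoder now forwards loss_function to BaseEncoder, a custom loss can also be injected directly at construction time; a sketch (the BasicAutoEncoder import is omitted; argument values are illustrative):

import torch

# Any BaseEncoder subclass accepts the new keyword argument;
# SmoothL1Loss is just an example replacement for the MSE default.
encoder = BasicAutoEncoder(name="demo", input_shape=64,
                           loss_function=torch.nn.SmoothL1Loss())
encoder.after_init()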


@@ -9,10 +9,10 @@ class BasicAutoEncoder(BaseEncoder):
# Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
name = "BasicAutoEncoder"
def __init__(self, name: Optional[str] = None, input_shape: int = 0):
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
self.log = logging.getLogger(self.__class__.__name__)
# Call superclass to initialize parameters.
super(BasicAutoEncoder, self).__init__(name, input_shape)
super(BasicAutoEncoder, self).__init__(name, input_shape, loss_function)
# Network, optimizer and loss function are the same as defined in the base encoder.


@@ -13,11 +13,11 @@ class ContractiveAutoEncoder(BaseEncoder):
# Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
name = "ContractiveAutoEncoder"
def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4):
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularizer_weight: float = 1e-4):
self.log = logging.getLogger(self.__class__.__name__)
# Call superclass to initialize parameters.
super(ContractiveAutoEncoder, self).__init__(name, input_shape)
super(ContractiveAutoEncoder, self).__init__(name, input_shape, loss_function)
self.regularizer_weight = regularizer_weight


@@ -15,12 +15,12 @@ class DenoisingAutoEncoder(BaseEncoder):
# Based on https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb
name = "DenoisingAutoEncoder"
def __init__(self, name: Optional[str] = None, input_shape: int = 0,
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None,
input_corruption_model: BaseCorruption = NoCorruption):
self.log = logging.getLogger(self.__class__.__name__)
# Call superclass to initialize parameters.
super(DenoisingAutoEncoder, self).__init__(name, input_shape)
super(DenoisingAutoEncoder, self).__init__(name, input_shape, loss_function)
# Network, optimizer and loss function are the same as defined in the base encoder.


@@ -14,11 +14,11 @@ class SparseL1AutoEncoder(BaseEncoder):
# Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
name = "SparseL1AutoEncoder"
def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularization_parameter: float = 0.001):
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularization_parameter: float = 0.001):
self.log = logging.getLogger(self.__class__.__name__)
# Call superclass to initialize parameters.
super(SparseL1AutoEncoder, self).__init__(name, input_shape)
super(SparseL1AutoEncoder, self).__init__(name, input_shape, loss_function)
# Override parameters to custom values for this encoder type


@@ -7,11 +7,53 @@ from typing import Optional
import numpy
import torch
from torch.nn.modules.loss import _Loss
from config import DATASET_STORAGE_BASE_PATH
from models.base_dataset import BaseDataset
class USWeatherLoss(_Loss):
__constants__ = ['reduction']
def __init__(self, dataset=None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
    super(USWeatherLoss, self).__init__(size_average, reduce, reduction)
    # Keep a reference to the dataset so the 1-hot label widths can be looked up per field.
    self.dataset = dataset
    self.ce_loss = torch.nn.CrossEntropyLoss()
    self.l1_loss = torch.nn.L1Loss()
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Both tensors are batch-first (batch_size, num_features); slice feature
    # columns (dim 1) rather than batch rows so every sample contributes.
    losses = []
    start = 0
    length = len(self.dataset._labels['Type'])
    # Type is 1-hot encoded, so use cross entropy loss
    losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
    start += length
    length = len(self.dataset._labels['Severity'])
    # Severity is 1-hot encoded, so use cross entropy loss
    losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
    start += length
    # Start time is a number, so use L1 loss
    losses.append(self.l1_loss(input[:, start], target[:, start]))
    # End time is a number, so use L1 loss
    losses.append(self.l1_loss(input[:, start + 1], target[:, start + 1]))
    start += 2
    length = len(self.dataset._labels['TimeZone'])
    # TimeZone is 1-hot encoded, so use cross entropy loss
    losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
    start += length
    # Location latitude is a number, so use L1 loss
    losses.append(self.l1_loss(input[:, start], target[:, start]))
    # Location longitude is a number, so use L1 loss
    losses.append(self.l1_loss(input[:, start + 1], target[:, start + 1]))
    start += 2
    length = len(self.dataset._labels['State'])
    # State is 1-hot encoded, so use cross entropy loss
    losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
    # Sum the per-field losses into a single scalar.
    return sum(losses)
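The feature vector this loss walks over is laid out as [Type 1-hot | Severity 1-hot | start time, end time | TimeZone 1-hot | latitude, longitude | State 1-hot]. A self-contained sketch with a stub dataset (the label values are made up, and batch-first 2-D tensors are assumed, matching the column slicing above):

import torch

class _StubDataset:
    # The only attribute USWeatherLoss reads; the label values are illustrative.
    _labels = {'Type': ['Rain', 'Snow'], 'Severity': ['Light', 'Heavy'],
               'TimeZone': ['US/Eastern'], 'State': ['CA', 'NY']}

loss_fn = USWeatherLoss(dataset=_StubDataset())
width = 2 + 2 + 2 + 1 + 2 + 2  # Type + Severity + times + TimeZone + lat/lng + State
reconstruction = torch.rand(4, width)
original = torch.rand(4, width)
print(loss_fn(reconstruction, original))  # scalar tensor: summed CE and L1 terms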
class USWeatherEventsDataset(BaseDataset):
# Source: https://smoosavi.org/datasets/lstw
# https://www.kaggle.com/sobhanmoosavi/us-weather-events
@@ -107,7 +149,9 @@ class USWeatherEventsDataset(BaseDataset):
pickle.dump(dict(self._labels), f)
self.log.info("Cached version created.")
train_data, test_data = self._data[:2500000], self._data[2500000:]
# train_data, test_data = self._data[:2500000], self._data[2500000:]
# Speed up training a bit
train_data, test_data = self._data[:50000], self._data[100000:150000]
self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
source_path=self._source_path)
@@ -140,42 +184,48 @@ class USWeatherEventsDataset(BaseDataset):
return data
def output_to_result_row(self, output):
# Get 1-hot encoded values as list per value, and other values as value
if not isinstance(output, list):
output = output.tolist()
start = 0
length = len(self._labels['Type'])
event_types = output[start:start+length]
start += length
length = len(self._labels['Severity'])
severities = output[start:start+length]
start += length
start_time = output[start]
end_time = output[start+1]
start += 2
length = len(self._labels['TimeZone'])
timezones = output[start:start+length]
start += length
location_lat = output[start]
location_lng = output[start+1]
start += 2
length = len(self._labels['State'])
states = output[start:start+length]
# Convert 1-hot encodings to normal labels, assume highest value as the true value.
event_type = self._labels['Type'][event_types.index(max(event_types))]
severity = self._labels['Severity'][severities.index(max(severities))]
timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
state = self._labels['State'][states.index(max(states))]
# Convert timestamp float into string time
start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
return [event_type, severity, start_time, end_time, timezone, location_lat, location_lng, state]
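Given one decoded sample (a 1-D vector or list), the method returns the eight CSV fields in file order; a usage sketch:

# `vec` is a single reconstructed sample, e.g. one row of a decoded batch.
fields = dataset.output_to_result_row(vec)
# -> [event_type, severity, start_time, end_time, timezone, lat, lng, state]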
def save_batch_to_sample(self, batch, filename):
res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
for row in batch:
# Get 1-hot encoded values as list per value, and other values as value
row = row.tolist()
start = 0
length = len(self._labels['Type'])
event_types = row[start:start+length]
start += length
length = len(self._labels['Severity'])
severities = row[start:start+length]
start += length
start_time = row[start]
end_time = row[start+1]
start += 2
length = len(self._labels['TimeZone'])
timezones = row[start:start+length]
start += length
location_lat = row[start]
location_lng = row[start+1]
start += 2
length = len(self._labels['State'])
states = row[start:start+length]
# Convert 1-hot encodings to normal labels, assume highest value as the true value.
event_type = self._labels['Type'][event_types.index(max(event_types))]
severity = self._labels['Severity'][severities.index(max(severities))]
timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
state = self._labels['State'][states.index(max(states))]
# Convert timestamp float into string time
start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
res.append(f"{event_type},{severity},{start_time},{end_time},{timezone},{location_lat},{location_lng},{state}\n")
res.append(",".join(map(lambda x: f'{x}', self.output_to_result_row(row)))+"\n")
with open(f"{filename}.csv", "w") as f:
f.writelines(res)
@@ -186,7 +236,10 @@ class USWeatherEventsDataset(BaseDataset):
total_score = 0
for i in range(len(originals)):
original, recon = originals[i], reconstruction[i]
original, recon = self.output_to_result_row(originals[i]), self.output_to_result_row(reconstruction[i])
total_score += sum(int(original[j] == recon[j]) for j in range(len(original))) / len(original)
return total_score / len(originals)
def get_loss_function(self):
return USWeatherLoss(dataset=self)
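As a worked example of the comparison score above: each row contributes the fraction of its eight decoded fields that match exactly, and the dataset score is the mean of those fractions (field values here are illustrative):

original = ['Rain', 'Light', '2020-01-01 00:00:00', '2020-01-01 01:00:00',
            'US/Eastern', 40.1, -80.2, 'OH']
recon    = ['Rain', 'Heavy', '2020-01-01 00:00:00', '2020-01-01 01:00:00',
            'US/Eastern', 40.1, -80.2, 'OH']
# 7 of the 8 fields agree, so this row scores 7/8 = 0.875.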


@@ -12,11 +12,11 @@ class VariationalAutoEncoder(BaseEncoder):
# and https://github.com/pytorch/examples/blob/master/vae/main.py
name = "VariationalAutoEncoder"
def __init__(self, name: Optional[str] = None, input_shape: int = 0):
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
self.log = logging.getLogger(self.__class__.__name__)
# Call superclass to initialize parameters.
super(VariationalAutoEncoder, self).__init__(name, input_shape)
super(VariationalAutoEncoder, self).__init__(name, input_shape, loss_function)
# VAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
# with no ReLU layer at the end of the encoder so we have access to the mu and variance.