Allow loss function to be defined by the dataset. Add specialized loss function for US weather dataset
This commit is contained in:
parent
f6a19c4921
commit
e4c51e2d3d
10
main.py
10
main.py
|
@ -36,13 +36,21 @@ def run_tests():
|
||||||
dataset = dataset_model(**test['dataset_kwargs'])
|
dataset = dataset_model(**test['dataset_kwargs'])
|
||||||
if test['encoder_kwargs'].get('input_shape', None) is None:
|
if test['encoder_kwargs'].get('input_shape', None) is None:
|
||||||
test['encoder_kwargs']['input_shape'] = dataset.get_input_shape()
|
test['encoder_kwargs']['input_shape'] = dataset.get_input_shape()
|
||||||
|
if test['encoder_kwargs'].get('loss_function', None) is None:
|
||||||
|
test['encoder_kwargs']['loss_function'] = dataset.get_loss_function()
|
||||||
encoder = encoder_model(**test['encoder_kwargs'])
|
encoder = encoder_model(**test['encoder_kwargs'])
|
||||||
encoder.after_init()
|
encoder.after_init()
|
||||||
corruption = corruption_model(**test['corruption_kwargs'])
|
corruption = corruption_model(**test['corruption_kwargs'])
|
||||||
test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
|
test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
|
||||||
|
|
||||||
# Run TestRun
|
# Run TestRun
|
||||||
test_run.run(retrain=False)
|
test_run.run(retrain=True)
|
||||||
|
|
||||||
|
# Cleanup to avoid out-of-memory situations when running lots of tests
|
||||||
|
del test_run
|
||||||
|
del corruption
|
||||||
|
del encoder
|
||||||
|
del dataset
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -65,6 +65,9 @@ class BaseDataset(Dataset):
|
||||||
def get_input_shape(self):
|
def get_input_shape(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_loss_function(self):
|
||||||
|
return torch.nn.MSELoss()
|
||||||
|
|
||||||
def _subdivide(self, amount: Union[int, float]):
|
def _subdivide(self, amount: Union[int, float]):
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
|
raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
|
||||||
|
|
|
@ -19,7 +19,7 @@ class BaseEncoder(torch.nn.Module):
|
||||||
# Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
|
# Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
|
||||||
name = "BaseEncoder"
|
name = "BaseEncoder"
|
||||||
|
|
||||||
def __init__(self, name: Optional[str] = None, input_shape: int = 0):
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
|
||||||
super(BaseEncoder, self).__init__()
|
super(BaseEncoder, self).__init__()
|
||||||
self.log = logging.getLogger(self.__class__.__name__)
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
|
@ -51,7 +51,10 @@ class BaseEncoder(torch.nn.Module):
|
||||||
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
|
|
||||||
# Mean Squared Error loss function
|
# Mean Squared Error loss function
|
||||||
self.loss_function = torch.nn.MSELoss()
|
if loss_function is not None:
|
||||||
|
self.loss_function = loss_function
|
||||||
|
else:
|
||||||
|
self.loss_function = torch.nn.MSELoss()
|
||||||
|
|
||||||
def after_init(self):
|
def after_init(self):
|
||||||
self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with "
|
self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with "
|
||||||
|
@ -163,6 +166,9 @@ class BaseEncoder(torch.nn.Module):
|
||||||
|
|
||||||
# display the epoch training loss
|
# display the epoch training loss
|
||||||
self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
|
self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
|
||||||
|
self.log.debug(f"Expected: {compare_features.cpu().detach().numpy()[0].tolist()}")
|
||||||
|
self.log.debug(f"Outputs: {outputs_for_loss.cpu().detach().numpy()[0].tolist()}")
|
||||||
|
self.log.debug(f"Loss: {train_loss}")
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
|
|
||||||
# Every 5 epochs, save a test image
|
# Every 5 epochs, save a test image
|
||||||
|
|
|
@ -9,10 +9,10 @@ class BasicAutoEncoder(BaseEncoder):
|
||||||
# Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
|
# Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
|
||||||
name = "BasicAutoEncoder"
|
name = "BasicAutoEncoder"
|
||||||
|
|
||||||
def __init__(self, name: Optional[str] = None, input_shape: int = 0):
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
|
||||||
self.log = logging.getLogger(self.__class__.__name__)
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
# Call superclass to initialize parameters.
|
# Call superclass to initialize parameters.
|
||||||
super(BasicAutoEncoder, self).__init__(name, input_shape)
|
super(BasicAutoEncoder, self).__init__(name, input_shape, loss_function)
|
||||||
|
|
||||||
# Network, optimizer and loss function are the same as defined in the base encoder.
|
# Network, optimizer and loss function are the same as defined in the base encoder.
|
||||||
|
|
|
@ -13,11 +13,11 @@ class ContractiveAutoEncoder(BaseEncoder):
|
||||||
# Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
|
# Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
|
||||||
name = "ContractiveAutoEncoder"
|
name = "ContractiveAutoEncoder"
|
||||||
|
|
||||||
def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4):
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularizer_weight: float = 1e-4):
|
||||||
self.log = logging.getLogger(self.__class__.__name__)
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
# Call superclass to initialize parameters.
|
# Call superclass to initialize parameters.
|
||||||
super(ContractiveAutoEncoder, self).__init__(name, input_shape)
|
super(ContractiveAutoEncoder, self).__init__(name, input_shape, loss_function)
|
||||||
|
|
||||||
self.regularizer_weight = regularizer_weight
|
self.regularizer_weight = regularizer_weight
|
||||||
|
|
||||||
|
|
|
@ -15,12 +15,12 @@ class DenoisingAutoEncoder(BaseEncoder):
|
||||||
# Based on https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb
|
# Based on https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb
|
||||||
name = "DenoisingAutoEncoder"
|
name = "DenoisingAutoEncoder"
|
||||||
|
|
||||||
def __init__(self, name: Optional[str] = None, input_shape: int = 0,
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None,
|
||||||
input_corruption_model: BaseCorruption = NoCorruption):
|
input_corruption_model: BaseCorruption = NoCorruption):
|
||||||
self.log = logging.getLogger(self.__class__.__name__)
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
# Call superclass to initialize parameters.
|
# Call superclass to initialize parameters.
|
||||||
super(DenoisingAutoEncoder, self).__init__(name, input_shape)
|
super(DenoisingAutoEncoder, self).__init__(name, input_shape, loss_function)
|
||||||
|
|
||||||
# Network, optimizer and loss function are the same as defined in the base encoder.
|
# Network, optimizer and loss function are the same as defined in the base encoder.
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,11 @@ class SparseL1AutoEncoder(BaseEncoder):
|
||||||
# Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
|
# Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
|
||||||
name = "SparseL1AutoEncoder"
|
name = "SparseL1AutoEncoder"
|
||||||
|
|
||||||
def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularization_parameter: float = 0.001):
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularization_parameter: float = 0.001):
|
||||||
self.log = logging.getLogger(self.__class__.__name__)
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
# Call superclass to initialize parameters.
|
# Call superclass to initialize parameters.
|
||||||
super(SparseL1AutoEncoder, self).__init__(name, input_shape)
|
super(SparseL1AutoEncoder, self).__init__(name, input_shape, loss_function)
|
||||||
|
|
||||||
# Override parameters to custom values for this encoder type
|
# Override parameters to custom values for this encoder type
|
||||||
|
|
||||||
|
|
|
@ -7,11 +7,53 @@ from typing import Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.modules.loss import _Loss
|
||||||
|
|
||||||
from config import DATASET_STORAGE_BASE_PATH
|
from config import DATASET_STORAGE_BASE_PATH
|
||||||
from models.base_dataset import BaseDataset
|
from models.base_dataset import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
class USWeatherLoss(_Loss):
|
||||||
|
__constants__ = ['reduction']
|
||||||
|
|
||||||
|
def __init__(self, dataset=None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
|
||||||
|
self.dataset = dataset
|
||||||
|
super(USWeatherLoss, self).__init__(size_average, reduce, reduction)
|
||||||
|
|
||||||
|
self.ce_loss = torch.nn.CrossEntropyLoss()
|
||||||
|
self.l1_loss = torch.nn.L1Loss()
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||||
|
losses = []
|
||||||
|
start = 0
|
||||||
|
length = len(self.dataset._labels['Type'])
|
||||||
|
# Type is 1-hot encoded, so use cross entropy loss
|
||||||
|
losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
|
||||||
|
start += length
|
||||||
|
length = len(self.dataset._labels['Severity'])
|
||||||
|
# Severity is 1-hot encoded, so use cross entropy loss
|
||||||
|
losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
|
||||||
|
start += length
|
||||||
|
# Start time is a number, so use L1 loss
|
||||||
|
losses.append(self.l1_loss(input[start], target[start]))
|
||||||
|
# End time is a number, so use L1 loss
|
||||||
|
losses.append(self.l1_loss(input[start + 1], target[start + 1]))
|
||||||
|
start += 2
|
||||||
|
length = len(self.dataset._labels['TimeZone'])
|
||||||
|
# TimeZone is 1-hot encoded, so use cross entropy loss
|
||||||
|
losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
|
||||||
|
start += length
|
||||||
|
# Location latitude is a number, so use L1 loss
|
||||||
|
losses.append(self.l1_loss(input[start], target[start]))
|
||||||
|
# Location longitude is a number, so use L1 loss
|
||||||
|
losses.append(self.l1_loss(input[start + 1], target[start + 1]))
|
||||||
|
start += 2
|
||||||
|
length = len(self.dataset._labels['State'])
|
||||||
|
# State is 1-hot encoded, so use cross entropy loss
|
||||||
|
losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
|
||||||
|
return sum(losses)
|
||||||
|
|
||||||
|
|
||||||
class USWeatherEventsDataset(BaseDataset):
|
class USWeatherEventsDataset(BaseDataset):
|
||||||
# Source: https://smoosavi.org/datasets/lstw
|
# Source: https://smoosavi.org/datasets/lstw
|
||||||
# https://www.kaggle.com/sobhanmoosavi/us-weather-events
|
# https://www.kaggle.com/sobhanmoosavi/us-weather-events
|
||||||
|
@ -107,7 +149,9 @@ class USWeatherEventsDataset(BaseDataset):
|
||||||
pickle.dump(dict(self._labels), f)
|
pickle.dump(dict(self._labels), f)
|
||||||
self.log.info("Cached version created.")
|
self.log.info("Cached version created.")
|
||||||
|
|
||||||
train_data, test_data = self._data[:2500000], self._data[2500000:]
|
# train_data, test_data = self._data[:2500000], self._data[2500000:]
|
||||||
|
# Speed up training a bit
|
||||||
|
train_data, test_data = self._data[:50000], self._data[100000:150000]
|
||||||
|
|
||||||
self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
|
self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
|
||||||
source_path=self._source_path)
|
source_path=self._source_path)
|
||||||
|
@ -140,42 +184,48 @@ class USWeatherEventsDataset(BaseDataset):
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def output_to_result_row(self, output):
|
||||||
|
# Get 1-hot encoded values as list per value, and other values as value
|
||||||
|
if not isinstance(output, list):
|
||||||
|
output = output.tolist()
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
length = len(self._labels['Type'])
|
||||||
|
event_types = output[start:start+length]
|
||||||
|
start += length
|
||||||
|
length = len(self._labels['Severity'])
|
||||||
|
severities = output[start:start+length]
|
||||||
|
start += length
|
||||||
|
start_time = output[start]
|
||||||
|
end_time = output[start+1]
|
||||||
|
start += 2
|
||||||
|
length = len(self._labels['TimeZone'])
|
||||||
|
timezones = output[start:start+length]
|
||||||
|
start += length
|
||||||
|
location_lat = output[start]
|
||||||
|
location_lng = output[start+1]
|
||||||
|
start += 2
|
||||||
|
length = len(self._labels['State'])
|
||||||
|
states = output[start:start+length]
|
||||||
|
|
||||||
|
# Convert 1-hot encodings to normal labels, assume highest value as the true value.
|
||||||
|
event_type = self._labels['Type'][event_types.index(max(event_types))]
|
||||||
|
severity = self._labels['Severity'][severities.index(max(severities))]
|
||||||
|
timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
|
||||||
|
state = self._labels['State'][states.index(max(states))]
|
||||||
|
|
||||||
|
# Convert timestamp float into string time
|
||||||
|
start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
return [event_type, severity, start_time, end_time, timezone, location_lat, location_lng, state]
|
||||||
|
|
||||||
def save_batch_to_sample(self, batch, filename):
|
def save_batch_to_sample(self, batch, filename):
|
||||||
res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
|
res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
|
||||||
|
|
||||||
for row in batch:
|
for row in batch:
|
||||||
# Get 1-hot encoded values as list per value, and other values as value
|
|
||||||
row = row.tolist()
|
row = row.tolist()
|
||||||
start = 0
|
res.append(",".join(map(lambda x: f'{x}', self.output_to_result_row(row)))+"\n")
|
||||||
length = len(self._labels['Type'])
|
|
||||||
event_types = row[start:start+length]
|
|
||||||
start += length
|
|
||||||
length = len(self._labels['Severity'])
|
|
||||||
severities = row[start:start+length]
|
|
||||||
start += length
|
|
||||||
start_time = row[start]
|
|
||||||
end_time = row[start+1]
|
|
||||||
start += 2
|
|
||||||
length = len(self._labels['TimeZone'])
|
|
||||||
timezones = row[start:start+length]
|
|
||||||
start += length
|
|
||||||
location_lat = row[start]
|
|
||||||
location_lng = row[start+1]
|
|
||||||
start += 2
|
|
||||||
length = len(self._labels['State'])
|
|
||||||
states = row[start:start+length]
|
|
||||||
|
|
||||||
# Convert 1-hot encodings to normal labels, assume highest value as the true value.
|
|
||||||
event_type = self._labels['Type'][event_types.index(max(event_types))]
|
|
||||||
severity = self._labels['Severity'][severities.index(max(severities))]
|
|
||||||
timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
|
|
||||||
state = self._labels['State'][states.index(max(states))]
|
|
||||||
|
|
||||||
# Convert timestamp float into string time
|
|
||||||
start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
|
|
||||||
res.append(f"{event_type},{severity},{start_time},{end_time},{timezone},{location_lat},{location_lng},{state}\n")
|
|
||||||
|
|
||||||
with open(f"{filename}.csv", "w") as f:
|
with open(f"{filename}.csv", "w") as f:
|
||||||
f.writelines(res)
|
f.writelines(res)
|
||||||
|
@ -186,7 +236,10 @@ class USWeatherEventsDataset(BaseDataset):
|
||||||
|
|
||||||
total_score = 0
|
total_score = 0
|
||||||
for i in range(len(originals)):
|
for i in range(len(originals)):
|
||||||
original, recon = originals[i], reconstruction[i]
|
original, recon = self.output_to_result_row(originals[i]), self.output_to_result_row(reconstruction[i])
|
||||||
total_score += sum(int(original[j] == recon[j]) for j in range(len(original))) / len(original)
|
total_score += sum(int(original[j] == recon[j]) for j in range(len(original))) / len(original)
|
||||||
|
|
||||||
return total_score / len(originals)
|
return total_score / len(originals)
|
||||||
|
|
||||||
|
def get_loss_function(self):
|
||||||
|
return USWeatherLoss(dataset=self)
|
||||||
|
|
|
@ -12,11 +12,11 @@ class VariationalAutoEncoder(BaseEncoder):
|
||||||
# and https://github.com/pytorch/examples/blob/master/vae/main.py
|
# and https://github.com/pytorch/examples/blob/master/vae/main.py
|
||||||
name = "VariationalAutoEncoder"
|
name = "VariationalAutoEncoder"
|
||||||
|
|
||||||
def __init__(self, name: Optional[str] = None, input_shape: int = 0):
|
def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
|
||||||
self.log = logging.getLogger(self.__class__.__name__)
|
self.log = logging.getLogger(self.__class__.__name__)
|
||||||
|
|
||||||
# Call superclass to initialize parameters.
|
# Call superclass to initialize parameters.
|
||||||
super(VariationalAutoEncoder, self).__init__(name, input_shape)
|
super(VariationalAutoEncoder, self).__init__(name, input_shape, loss_function)
|
||||||
|
|
||||||
# VAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
|
# VAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
|
||||||
# with no ReLU layer at the end of the encoder so we have access to the mu and variance.
|
# with no ReLU layer at the end of the encoder so we have access to the mu and variance.
|
||||||
|
|
Loading…
Reference in a new issue