Allow loss function to be defined by the dataset. Add specialized loss function for US weather dataset
parent f6a19c4921
commit e4c51e2d3d

main.py
@@ -36,13 +36,21 @@ def run_tests():
         dataset = dataset_model(**test['dataset_kwargs'])
         if test['encoder_kwargs'].get('input_shape', None) is None:
             test['encoder_kwargs']['input_shape'] = dataset.get_input_shape()
+        if test['encoder_kwargs'].get('loss_function', None) is None:
+            test['encoder_kwargs']['loss_function'] = dataset.get_loss_function()
         encoder = encoder_model(**test['encoder_kwargs'])
         encoder.after_init()
         corruption = corruption_model(**test['corruption_kwargs'])
         test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)

         # Run TestRun
         test_run.run(retrain=False)
         test_run.run(retrain=True)

+        # Cleanup to avoid out-of-memory situations when running lots of tests
+        del test_run
+        del corruption
+        del encoder
+        del dataset
+

 if __name__ == '__main__':
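For context, a minimal sketch of a test definition that exercises the new fallback; the exact structure of the test list and the names used are assumptions, since the config is not part of this diff:

    # Hypothetical entry in the test configuration iterated by run_tests().
    tests = [{
        'dataset_kwargs': {'name': 'US Weather Events'},
        # No 'input_shape' or 'loss_function' given, so run_tests() fills them
        # in from dataset.get_input_shape() and dataset.get_loss_function().
        'encoder_kwargs': {'name': 'Basic run'},
        'corruption_kwargs': {},
    }]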
@@ -65,6 +65,9 @@ class BaseDataset(Dataset):
     def get_input_shape(self):
         return None

+    def get_loss_function(self):
+        return torch.nn.MSELoss()
+
     def _subdivide(self, amount: Union[int, float]):
         if self._data is None:
             raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
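With MSE as the base-class default, a subclass only overrides get_loss_function() when it needs something else; a minimal sketch with a hypothetical dataset class:

    class BinaryImageDataset(BaseDataset):
        # Hypothetical subclass: inputs normalized to [0, 1] pair well with BCE.
        def get_loss_function(self):
            return torch.nn.BCELoss()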
@@ -19,7 +19,7 @@ class BaseEncoder(torch.nn.Module):
     # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
     name = "BaseEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
         super(BaseEncoder, self).__init__()
         self.log = logging.getLogger(self.__class__.__name__)
@@ -51,7 +51,10 @@ class BaseEncoder(torch.nn.Module):
         self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

         # Mean Squared Error loss function
-        self.loss_function = torch.nn.MSELoss()
+        if loss_function is not None:
+            self.loss_function = loss_function
+        else:
+            self.loss_function = torch.nn.MSELoss()

     def after_init(self):
         self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with "
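The if/else is equivalent to `self.loss_function = loss_function if loss_function is not None else torch.nn.MSELoss()`; either way, callers can now inject any criterion. A usage sketch (encoder name and shape illustrative):

    encoder = BasicAutoEncoder(name="L1 run", input_shape=784,
                               loss_function=torch.nn.L1Loss())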
@@ -163,6 +166,9 @@ class BaseEncoder(torch.nn.Module):
             # display the epoch training loss
             self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
+            self.log.debug(f"Expected: {compare_features.cpu().detach().numpy()[0].tolist()}")
+            self.log.debug(f"Outputs: {outputs_for_loss.cpu().detach().numpy()[0].tolist()}")
+            self.log.debug(f"Loss: {train_loss}")
             losses.append(loss)

             # Every 5 epochs, save a test image
@@ -9,10 +9,10 @@ class BasicAutoEncoder(BaseEncoder):
     # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
     name = "BasicAutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(BasicAutoEncoder, self).__init__(name, input_shape)
+        super(BasicAutoEncoder, self).__init__(name, input_shape, loss_function)

         # Network, optimizer and loss function are the same as defined in the base encoder.
@@ -13,11 +13,11 @@ class ContractiveAutoEncoder(BaseEncoder):
     # Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
     name = "ContractiveAutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularizer_weight: float = 1e-4):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(ContractiveAutoEncoder, self).__init__(name, input_shape)
+        super(ContractiveAutoEncoder, self).__init__(name, input_shape, loss_function)

         self.regularizer_weight = regularizer_weight
@@ -15,12 +15,12 @@ class DenoisingAutoEncoder(BaseEncoder):
     # Based on https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb
     name = "DenoisingAutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0,
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None,
                  input_corruption_model: BaseCorruption = NoCorruption):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(DenoisingAutoEncoder, self).__init__(name, input_shape)
+        super(DenoisingAutoEncoder, self).__init__(name, input_shape, loss_function)

         # Network, optimizer and loss function are the same as defined in the base encoder.
@@ -14,11 +14,11 @@ class SparseL1AutoEncoder(BaseEncoder):
     # Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
     name = "SparseL1AutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularization_parameter: float = 0.001):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularization_parameter: float = 0.001):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(SparseL1AutoEncoder, self).__init__(name, input_shape)
+        super(SparseL1AutoEncoder, self).__init__(name, input_shape, loss_function)

         # Override parameters to custom values for this encoder type
@@ -7,11 +7,53 @@ from typing import Optional

 import numpy
 import torch
+from torch.nn.modules.loss import _Loss

 from config import DATASET_STORAGE_BASE_PATH
 from models.base_dataset import BaseDataset


+class USWeatherLoss(_Loss):
+    __constants__ = ['reduction']
+
+    def __init__(self, dataset=None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
+        self.dataset = dataset
+        super(USWeatherLoss, self).__init__(size_average, reduce, reduction)
+
+        self.ce_loss = torch.nn.CrossEntropyLoss()
+        self.l1_loss = torch.nn.L1Loss()
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        losses = []
+        start = 0
+        length = len(self.dataset._labels['Type'])
+        # Type is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        start += length
+        length = len(self.dataset._labels['Severity'])
+        # Severity is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        start += length
+        # Start time is a number, so use L1 loss
+        losses.append(self.l1_loss(input[start], target[start]))
+        # End time is a number, so use L1 loss
+        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
+        start += 2
+        length = len(self.dataset._labels['TimeZone'])
+        # TimeZone is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        start += length
+        # Location latitude is a number, so use L1 loss
+        losses.append(self.l1_loss(input[start], target[start]))
+        # Location longitude is a number, so use L1 loss
+        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
+        start += 2
+        length = len(self.dataset._labels['State'])
+        # State is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        return sum(losses)
+
+
 class USWeatherEventsDataset(BaseDataset):
     # Source: https://smoosavi.org/datasets/lstw
     # https://www.kaggle.com/sobhanmoosavi/us-weather-events
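The idea behind USWeatherLoss, as a self-contained sketch with a made-up two-field layout: cross-entropy over each one-hot block, L1 over scalar fields, summed into one loss. Note the sketch slices the feature dimension (dim=1) of a batched (N, features) tensor, whereas the forward above indexes input[start:start+length] along dim 0, so verify the batch layout before reusing either form.

    import torch

    ce = torch.nn.CrossEntropyLoss()
    l1 = torch.nn.L1Loss()

    # Toy layout: [3 one-hot 'Type' slots | 2 scalar timestamps], batch of 1.
    recon  = torch.tensor([[2.0, 0.1, 0.3, 0.5, 0.9]])  # model output
    target = torch.tensor([[1.0, 0.0, 0.0, 0.4, 1.0]])  # ground truth

    type_loss = ce(recon[:, 0:3], torch.argmax(target[:, 0:3], dim=1))
    time_loss = l1(recon[:, 3:5], target[:, 3:5])
    total = type_loss + time_loss  # scalar tensor, ready for backward()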
@@ -107,7 +149,9 @@ class USWeatherEventsDataset(BaseDataset):
                 pickle.dump(dict(self._labels), f)
             self.log.info("Cached version created.")

-        train_data, test_data = self._data[:2500000], self._data[2500000:]
+        # train_data, test_data = self._data[:2500000], self._data[2500000:]
+        # Speed up training a bit
+        train_data, test_data = self._data[:50000], self._data[100000:150000]

         self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
                                                 source_path=self._source_path)
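The hard-coded slices are quick but take contiguous (and possibly time-ordered) rows; if that bias matters, a random split is one alternative, sketched here assuming self._data supports index-array selection:

    # Alternative: sample train/test rows at random instead of contiguous slices.
    idx = torch.randperm(len(self._data))
    train_data = self._data[idx[:50000]]
    test_data = self._data[idx[50000:100000]]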
@@ -140,42 +184,48 @@ class USWeatherEventsDataset(BaseDataset):

         return data

+    def output_to_result_row(self, output):
+        # Get 1-hot encoded values as list per value, and other values as value
+        if not isinstance(output, list):
+            output = output.tolist()
+
+        start = 0
+        length = len(self._labels['Type'])
+        event_types = output[start:start+length]
+        start += length
+        length = len(self._labels['Severity'])
+        severities = output[start:start+length]
+        start += length
+        start_time = output[start]
+        end_time = output[start+1]
+        start += 2
+        length = len(self._labels['TimeZone'])
+        timezones = output[start:start+length]
+        start += length
+        location_lat = output[start]
+        location_lng = output[start+1]
+        start += 2
+        length = len(self._labels['State'])
+        states = output[start:start+length]
+
+        # Convert 1-hot encodings to normal labels, assume highest value as the true value.
+        event_type = self._labels['Type'][event_types.index(max(event_types))]
+        severity = self._labels['Severity'][severities.index(max(severities))]
+        timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
+        state = self._labels['State'][states.index(max(states))]
+
+        # Convert timestamp float into string time
+        start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
+        end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
+
+        return [event_type, severity, start_time, end_time, timezone, location_lat, location_lng, state]
+
     def save_batch_to_sample(self, batch, filename):
         res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]

         for row in batch:
-            # Get 1-hot encoded values as list per value, and other values as value
-            row = row.tolist()
-            start = 0
-            length = len(self._labels['Type'])
-            event_types = row[start:start+length]
-            start += length
-            length = len(self._labels['Severity'])
-            severities = row[start:start+length]
-            start += length
-            start_time = row[start]
-            end_time = row[start+1]
-            start += 2
-            length = len(self._labels['TimeZone'])
-            timezones = row[start:start+length]
-            start += length
-            location_lat = row[start]
-            location_lng = row[start+1]
-            start += 2
-            length = len(self._labels['State'])
-            states = row[start:start+length]
-
-            # Convert 1-hot encodings to normal labels, assume highest value as the true value.
-            event_type = self._labels['Type'][event_types.index(max(event_types))]
-            severity = self._labels['Severity'][severities.index(max(severities))]
-            timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
-            state = self._labels['State'][states.index(max(states))]
-
-            # Convert timestamp float into string time
-            start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
-            end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
-
-            res.append(f"{event_type},{severity},{start_time},{end_time},{timezone},{location_lat},{location_lng},{state}\n")
+            res.append(",".join(map(lambda x: f'{x}', self.output_to_result_row(row)))+"\n")

         with open(f"{filename}.csv", "w") as f:
             f.writelines(res)
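Factoring the decoding into output_to_result_row removes the duplicated slicing logic and lets the reconstruction scorer below reuse it. A usage sketch (values purely illustrative):

    row = dataset.output_to_result_row(model_output)  # tensor or list in, labels out
    # -> ['Rain', 'Light', '2017-01-01 06:00:00', '2017-01-01 07:30:00',
    #     'US/Eastern', 41.3, -81.5, 'OH']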
@@ -186,7 +236,10 @@ class USWeatherEventsDataset(BaseDataset):

         total_score = 0
         for i in range(len(originals)):
-            original, recon = originals[i], reconstruction[i]
+            original, recon = self.output_to_result_row(originals[i]), self.output_to_result_row(reconstruction[i])
             total_score += sum(int(original[j] == recon[j]) for j in range(len(original))) / len(original)

         return total_score / len(originals)

+    def get_loss_function(self):
+        return USWeatherLoss(dataset=self)
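End to end, the pieces now connect like this (constructor arguments elided):

    dataset = USWeatherEventsDataset(...)
    loss_fn = dataset.get_loss_function()     # USWeatherLoss(dataset=dataset)
    loss = loss_fn(reconstruction, original)  # mixed cross-entropy + L1
    loss.backward()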
@@ -12,11 +12,11 @@ class VariationalAutoEncoder(BaseEncoder):
     # and https://github.com/pytorch/examples/blob/master/vae/main.py
     name = "VariationalAutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(VariationalAutoEncoder, self).__init__(name, input_shape)
+        super(VariationalAutoEncoder, self).__init__(name, input_shape, loss_function)

         # VAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
         # with no ReLU layer at the end of the encoder so we have access to the mu and variance.