From e4c51e2d3d23658706ceb58083afc48567428025 Mon Sep 17 00:00:00 2001
From: Kevin Alberts
Date: Thu, 21 Jan 2021 21:47:00 +0100
Subject: [PATCH] Allow loss function to be defined by the dataset. Add
 specialized loss function for US weather dataset

---
 main.py                       |  10 ++-
 models/base_dataset.py        |   3 +
 models/base_encoder.py        |  10 ++-
 models/basic_encoder.py       |   4 +-
 models/contractive_encoder.py |   4 +-
 models/denoising_encoder.py   |   4 +-
 models/sparse_encoder.py      |   4 +-
 models/usweather_dataset.py   | 119 ++++++++++++++++++++++++----------
 models/variational_encoder.py |   4 +-
 9 files changed, 116 insertions(+), 46 deletions(-)

diff --git a/main.py b/main.py
index 8e87eb0..294f460 100644
--- a/main.py
+++ b/main.py
@@ -36,13 +36,21 @@ def run_tests():
         dataset = dataset_model(**test['dataset_kwargs'])
         if test['encoder_kwargs'].get('input_shape', None) is None:
             test['encoder_kwargs']['input_shape'] = dataset.get_input_shape()
+        if test['encoder_kwargs'].get('loss_function', None) is None:
+            test['encoder_kwargs']['loss_function'] = dataset.get_loss_function()
         encoder = encoder_model(**test['encoder_kwargs'])
         encoder.after_init()
         corruption = corruption_model(**test['corruption_kwargs'])
         test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
 
         # Run TestRun
-        test_run.run(retrain=False)
+        test_run.run(retrain=True)
+
+        # Cleanup to avoid out-of-memory situations when running lots of tests
+        del test_run
+        del corruption
+        del encoder
+        del dataset
 
 
 if __name__ == '__main__':
diff --git a/models/base_dataset.py b/models/base_dataset.py
index 091a076..feb1fce 100644
--- a/models/base_dataset.py
+++ b/models/base_dataset.py
@@ -65,6 +65,9 @@ class BaseDataset(Dataset):
     def get_input_shape(self):
         return None
 
+    def get_loss_function(self):
+        return torch.nn.MSELoss()
+
     def _subdivide(self, amount: Union[int, float]):
         if self._data is None:
             raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
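The new get_loss_function() hook lets each dataset choose the training criterion, with MSE as the fallback. A minimal sketch of a dataset overriding it (this SignalDataset and its choice of L1Loss are illustrative, not part of the patch):

    import torch

    from models.base_dataset import BaseDataset


    class SignalDataset(BaseDataset):
        # Hypothetical data with heavy-tailed noise, where absolute error
        # fits better than the squared-error default.
        def get_loss_function(self):
            return torch.nn.L1Loss()

Note that main.py only injects the dataset's loss when the test configuration does not already specify one, so per-test overrides still take precedence.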
diff --git a/models/base_encoder.py b/models/base_encoder.py
index df386c3..017ba3e 100644
--- a/models/base_encoder.py
+++ b/models/base_encoder.py
@@ -19,7 +19,7 @@ class BaseEncoder(torch.nn.Module):
     # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
     name = "BaseEncoder"
 
-    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
         super(BaseEncoder, self).__init__()
 
         self.log = logging.getLogger(self.__class__.__name__)
@@ -51,7 +51,10 @@ class BaseEncoder(torch.nn.Module):
         self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
 
-        # Mean Squared Error loss function
-        self.loss_function = torch.nn.MSELoss()
+        # Use the supplied loss function, or fall back to Mean Squared Error
+        if loss_function is not None:
+            self.loss_function = loss_function
+        else:
+            self.loss_function = torch.nn.MSELoss()
 
     def after_init(self):
         self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with "
@@ -163,6 +166,9 @@ class BaseEncoder(torch.nn.Module):
 
             # display the epoch training loss
             self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
+            self.log.debug(f"Expected: {compare_features.cpu().detach().numpy()[0].tolist()}")
+            self.log.debug(f"Outputs: {outputs_for_loss.cpu().detach().numpy()[0].tolist()}")
+            self.log.debug(f"Loss: {train_loss}")
             losses.append(loss)
 
             # Every 5 epochs, save a test image
diff --git a/models/basic_encoder.py b/models/basic_encoder.py
index 0746175..6e7f047 100644
--- a/models/basic_encoder.py
+++ b/models/basic_encoder.py
@@ -9,10 +9,10 @@ class BasicAutoEncoder(BaseEncoder):
     # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
     name = "BasicAutoEncoder"
 
-    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
        self.log = logging.getLogger(self.__class__.__name__)
 
         # Call superclass to initialize parameters.
-        super(BasicAutoEncoder, self).__init__(name, input_shape)
+        super(BasicAutoEncoder, self).__init__(name, input_shape, loss_function)
 
         # Network, optimizer and loss function are the same as defined in the base encoder.
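Because loss_function defaults to None, existing call sites keep their old MSE behaviour. A short usage sketch (the input_shape value and the SmoothL1Loss choice are invented for illustration):

    import torch

    from models.basic_encoder import BasicAutoEncoder

    # Explicit criterion, e.g. the one returned by dataset.get_loss_function():
    encoder = BasicAutoEncoder(input_shape=128, loss_function=torch.nn.SmoothL1Loss())

    # No criterion given: BaseEncoder falls back to torch.nn.MSELoss().
    encoder_default = BasicAutoEncoder(input_shape=128)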
diff --git a/models/contractive_encoder.py b/models/contractive_encoder.py
index 98685da..eeadb98 100644
--- a/models/contractive_encoder.py
+++ b/models/contractive_encoder.py
@@ -13,11 +13,11 @@ class ContractiveAutoEncoder(BaseEncoder):
     # Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
     name = "ContractiveAutoEncoder"
 
-    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularizer_weight: float = 1e-4):
         self.log = logging.getLogger(self.__class__.__name__)
 
         # Call superclass to initialize parameters.
-        super(ContractiveAutoEncoder, self).__init__(name, input_shape)
+        super(ContractiveAutoEncoder, self).__init__(name, input_shape, loss_function)
 
         self.regularizer_weight = regularizer_weight
 
diff --git a/models/denoising_encoder.py b/models/denoising_encoder.py
index 8af747a..f1e8e5f 100644
--- a/models/denoising_encoder.py
+++ b/models/denoising_encoder.py
@@ -15,12 +15,12 @@ class DenoisingAutoEncoder(BaseEncoder):
     # Based on https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb
     name = "DenoisingAutoEncoder"
 
-    def __init__(self, name: Optional[str] = None, input_shape: int = 0,
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None,
                  input_corruption_model: BaseCorruption = NoCorruption):
         self.log = logging.getLogger(self.__class__.__name__)
 
         # Call superclass to initialize parameters.
-        super(DenoisingAutoEncoder, self).__init__(name, input_shape)
+        super(DenoisingAutoEncoder, self).__init__(name, input_shape, loss_function)
 
         # Network, optimizer and loss function are the same as defined in the base encoder.
diff --git a/models/sparse_encoder.py b/models/sparse_encoder.py
index 720a4b1..0be13c5 100644
--- a/models/sparse_encoder.py
+++ b/models/sparse_encoder.py
@@ -14,11 +14,11 @@ class SparseL1AutoEncoder(BaseEncoder):
     # Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
     name = "SparseL1AutoEncoder"
 
-    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularization_parameter: float = 0.001):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularization_parameter: float = 0.001):
         self.log = logging.getLogger(self.__class__.__name__)
 
         # Call superclass to initialize parameters.
-        super(SparseL1AutoEncoder, self).__init__(name, input_shape)
+        super(SparseL1AutoEncoder, self).__init__(name, input_shape, loss_function)
 
         # Override parameters to custom values for this encoder type
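Note that loss_function is inserted before each subclass's own hyper-parameters (regularizer_weight, regularization_parameter), so any caller passing those positionally needs updating. The regularized encoders keep their penalty terms; the injected criterion only replaces the reconstruction part. Schematically (a sketch of the idea, not code from this repository):

    import torch

    def regularized_loss(loss_function, outputs, targets, activations,
                         weight: float = 0.001) -> torch.Tensor:
        # Reconstruction term from the (possibly dataset-supplied) criterion,
        # plus an L1 sparsity penalty in the style of SparseL1AutoEncoder.
        return loss_function(outputs, targets) + weight * activations.abs().mean()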
diff --git a/models/usweather_dataset.py b/models/usweather_dataset.py
index 5ba1a33..6d71f06 100644
--- a/models/usweather_dataset.py
+++ b/models/usweather_dataset.py
@@ -7,11 +7,53 @@ from typing import Optional
 
 import numpy
 import torch
+from torch.nn.modules.loss import _Loss
 
 from config import DATASET_STORAGE_BASE_PATH
 from models.base_dataset import BaseDataset
 
 
+class USWeatherLoss(_Loss):
+    __constants__ = ['reduction']
+
+    def __init__(self, dataset=None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
+        self.dataset = dataset
+        super(USWeatherLoss, self).__init__(size_average, reduce, reduction)
+
+        self.ce_loss = torch.nn.CrossEntropyLoss()
+        self.l1_loss = torch.nn.L1Loss()
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        # Slice along dim 1 for every field: dim 0 is the batch dimension.
+        losses = []
+        start = 0
+        length = len(self.dataset._labels['Type'])
+        # Type is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
+        start += length
+        length = len(self.dataset._labels['Severity'])
+        # Severity is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
+        start += length
+        # Start time is a number, so use L1 loss
+        losses.append(self.l1_loss(input[:, start], target[:, start]))
+        # End time is a number, so use L1 loss
+        losses.append(self.l1_loss(input[:, start + 1], target[:, start + 1]))
+        start += 2
+        length = len(self.dataset._labels['TimeZone'])
+        # TimeZone is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
+        start += length
+        # Location latitude is a number, so use L1 loss
+        losses.append(self.l1_loss(input[:, start], target[:, start]))
+        # Location longitude is a number, so use L1 loss
+        losses.append(self.l1_loss(input[:, start + 1], target[:, start + 1]))
+        start += 2
+        length = len(self.dataset._labels['State'])
+        # State is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
+        return sum(losses)
+
+
 class USWeatherEventsDataset(BaseDataset):
     # Source: https://smoosavi.org/datasets/lstw
     # https://www.kaggle.com/sobhanmoosavi/us-weather-events
@@ -107,7 +149,9 @@
                 pickle.dump(dict(self._labels), f)
             self.log.info("Cached version created.")
 
-        train_data, test_data = self._data[:2500000], self._data[2500000:]
+        # train_data, test_data = self._data[:2500000], self._data[2500000:]
+        # Speed up training a bit
+        train_data, test_data = self._data[:50000], self._data[100000:150000]
 
         self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data,
                                                 labels=self._labels, source_path=self._source_path)
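USWeatherLoss walks the concatenated feature vector field by field, in the same order the dataset encodes it (Type, Severity, start/end time, TimeZone, lat/lng, State): cross entropy for the 1-hot segments, L1 for the scalars, so the slice offsets must line up exactly with the encoding. The same pattern on a self-contained toy layout (field widths invented for illustration):

    import torch

    ce = torch.nn.CrossEntropyLoss()
    l1 = torch.nn.L1Loss()

    one_hot_widths = {'Type': 5, 'Severity': 4}  # made-up segment sizes
    batch, width = 8, 5 + 4 + 2                  # two scalar time fields at the end
    output = torch.randn(batch, width)           # reconstruction from the encoder
    target = torch.randn(batch, width)           # original encoded rows

    losses, start = [], 0
    for length in one_hot_widths.values():
        # 1-hot segment: logits vs. the index of the hot entry
        losses.append(ce(output[:, start:start + length],
                         torch.argmax(target[:, start:start + length], dim=1)))
        start += length
    losses.append(l1(output[:, start], target[:, start]))          # start time
    losses.append(l1(output[:, start + 1], target[:, start + 1]))  # end time
    total = sum(losses)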
@@ -140,42 +184,48 @@
 
         return data
 
+    def output_to_result_row(self, output):
+        # Get 1-hot encoded values as list per value, and other values as value
+        if not isinstance(output, list):
+            output = output.tolist()
+
+        start = 0
+        length = len(self._labels['Type'])
+        event_types = output[start:start+length]
+        start += length
+        length = len(self._labels['Severity'])
+        severities = output[start:start+length]
+        start += length
+        start_time = output[start]
+        end_time = output[start+1]
+        start += 2
+        length = len(self._labels['TimeZone'])
+        timezones = output[start:start+length]
+        start += length
+        location_lat = output[start]
+        location_lng = output[start+1]
+        start += 2
+        length = len(self._labels['State'])
+        states = output[start:start+length]
+
+        # Convert 1-hot encodings to normal labels, assume highest value as the true value.
+        event_type = self._labels['Type'][event_types.index(max(event_types))]
+        severity = self._labels['Severity'][severities.index(max(severities))]
+        timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
+        state = self._labels['State'][states.index(max(states))]
+
+        # Convert timestamp float into string time
+        start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
+        end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
+
+        return [event_type, severity, start_time, end_time, timezone, location_lat, location_lng, state]
+
     def save_batch_to_sample(self, batch, filename):
         res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
 
         for row in batch:
-            # Get 1-hot encoded values as list per value, and other values as value
             row = row.tolist()
-            start = 0
-            length = len(self._labels['Type'])
-            event_types = row[start:start+length]
-            start += length
-            length = len(self._labels['Severity'])
-            severities = row[start:start+length]
-            start += length
-            start_time = row[start]
-            end_time = row[start+1]
-            start += 2
-            length = len(self._labels['TimeZone'])
-            timezones = row[start:start+length]
-            start += length
-            location_lat = row[start]
-            location_lng = row[start+1]
-            start += 2
-            length = len(self._labels['State'])
-            states = row[start:start+length]
-
-            # Convert 1-hot encodings to normal labels, assume highest value as the true value.
-            event_type = self._labels['Type'][event_types.index(max(event_types))]
-            severity = self._labels['Severity'][severities.index(max(severities))]
-            timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
-            state = self._labels['State'][states.index(max(states))]
-
-            # Convert timestamp float into string time
-            start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
-            end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
-
-            res.append(f"{event_type},{severity},{start_time},{end_time},{timezone},{location_lat},{location_lng},{state}\n")
+            res.append(",".join(map(str, self.output_to_result_row(row))) + "\n")
 
         with open(f"{filename}.csv", "w") as f:
             f.writelines(res)
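Factoring the decoding into output_to_result_row gives the CSV sample writer and the reconstruction score below a single definition of what a row means. The 1-hot decode itself is just a highest-value lookup; in plain Python:

    labels = ['Rain', 'Snow', 'Fog']               # example labels, not from the dataset
    segment = [0.1, 0.7, 0.2]                      # one decoded slice of a model output
    decoded = labels[segment.index(max(segment))]  # -> 'Snow'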
@@ -186,7 +236,10 @@
 
         total_score = 0
         for i in range(len(originals)):
-            original, recon = originals[i], reconstruction[i]
+            original, recon = self.output_to_result_row(originals[i]), self.output_to_result_row(reconstruction[i])
             total_score += sum(int(original[j] == recon[j]) for j in range(len(original))) / len(original)
 
         return total_score / len(originals)
+
+    def get_loss_function(self):
+        return USWeatherLoss(dataset=self)
diff --git a/models/variational_encoder.py b/models/variational_encoder.py
index 4b20529..30af130 100644
--- a/models/variational_encoder.py
+++ b/models/variational_encoder.py
@@ -12,11 +12,11 @@ class VariationalAutoEncoder(BaseEncoder):
     # and https://github.com/pytorch/examples/blob/master/vae/main.py
     name = "VariationalAutoEncoder"
 
-    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
         self.log = logging.getLogger(self.__class__.__name__)
 
         # Call superclass to initialize parameters.
-        super(VariationalAutoEncoder, self).__init__(name, input_shape)
+        super(VariationalAutoEncoder, self).__init__(name, input_shape, loss_function)
 
         # VAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
         # with no ReLU layer at the end of the encoder so we have access to the mu and variance.
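End to end, the flow is: main.py asks the dataset for its loss, injects it through encoder_kwargs, and every encoder subclass forwards it to BaseEncoder. Condensed (constructor arguments are omitted or invented; the real values come from the test configuration):

    from models.basic_encoder import BasicAutoEncoder
    from models.usweather_dataset import USWeatherEventsDataset

    dataset = USWeatherEventsDataset()
    dataset.load()
    encoder = BasicAutoEncoder(input_shape=dataset.get_input_shape(),
                               loss_function=dataset.get_loss_function())
    encoder.after_init()

For the weather data this also means reconstruction_score now compares decoded rows (labels and formatted timestamps) rather than raw tensors, so a field only counts as correct when it decodes to exactly the same value.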