Allow loss function to be defined by the dataset. Add specialized loss function for US weather dataset
parent f6a19c4921
commit e4c51e2d3d

9 changed files with 116 additions and 46 deletions

main.py | 10
@@ -36,13 +36,21 @@ def run_tests():
         dataset = dataset_model(**test['dataset_kwargs'])
         if test['encoder_kwargs'].get('input_shape', None) is None:
             test['encoder_kwargs']['input_shape'] = dataset.get_input_shape()
+        if test['encoder_kwargs'].get('loss_function', None) is None:
+            test['encoder_kwargs']['loss_function'] = dataset.get_loss_function()
         encoder = encoder_model(**test['encoder_kwargs'])
         encoder.after_init()
         corruption = corruption_model(**test['corruption_kwargs'])
         test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)

         # Run TestRun
         test_run.run(retrain=False)
         test_run.run(retrain=True)

         # Cleanup to avoid out-of-memory situations when running lots of tests
         del test_run
         del corruption
         del encoder
         del dataset


 if __name__ == '__main__':
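The two added lines let the dataset supply a loss function whenever a test does not specify one. A minimal sketch of what a test entry might look like under this scheme; the dict contents and the SmoothL1Loss choice are illustrative assumptions, not taken from the repository:

    import torch

    # No 'loss_function' key: run_tests() fills it in from dataset.get_loss_function().
    test_default = {
        'dataset_kwargs': {},
        'encoder_kwargs': {'input_shape': 64},
        'corruption_kwargs': {},
    }

    # Explicit loss: the dataset default is skipped.
    test_explicit = {
        'dataset_kwargs': {},
        'encoder_kwargs': {'input_shape': 64, 'loss_function': torch.nn.SmoothL1Loss()},
        'corruption_kwargs': {},
    }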
@@ -65,6 +65,9 @@ class BaseDataset(Dataset):
     def get_input_shape(self):
         return None

+    def get_loss_function(self):
+        return torch.nn.MSELoss()
+
     def _subdivide(self, amount: Union[int, float]):
         if self._data is None:
             raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
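The base class keeps the old MSE behaviour as its default, so existing datasets are unaffected; a dataset opts in by overriding the hook. A minimal sketch of such an override (the subclass name and the L1Loss choice are invented for illustration):

    import torch
    from models.base_dataset import BaseDataset

    class RobustRegressionDataset(BaseDataset):  # hypothetical subclass
        def get_loss_function(self):
            # Prefer absolute error over the MSE default, e.g. for outlier-heavy data.
            return torch.nn.L1Loss()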
@@ -19,7 +19,7 @@ class BaseEncoder(torch.nn.Module):
     # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
     name = "BaseEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
         super(BaseEncoder, self).__init__()
         self.log = logging.getLogger(self.__class__.__name__)

@@ -51,6 +51,9 @@ class BaseEncoder(torch.nn.Module):
         self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

         # Use the caller-supplied loss function if one was given; fall back to Mean Squared Error
+        if loss_function is not None:
+            self.loss_function = loss_function
+        else:
+            self.loss_function = torch.nn.MSELoss()

     def after_init(self):
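Together with the run_tests() change above, the effective precedence is: an explicitly passed loss_function wins, then the dataset's get_loss_function() injected by the harness, then the MSE fallback here. A usage sketch under that assumption (SmoothL1Loss is an arbitrary example):

    import torch

    enc_default = BaseEncoder(input_shape=64)                       # -> torch.nn.MSELoss()
    enc_custom = BaseEncoder(input_shape=64,
                             loss_function=torch.nn.SmoothL1Loss())  # explicit loss wins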
@@ -163,6 +166,9 @@ class BaseEncoder(torch.nn.Module):

             # display the epoch training loss
             self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
+            self.log.debug(f"Expected: {compare_features.cpu().detach().numpy()[0].tolist()}")
+            self.log.debug(f"Outputs:  {outputs_for_loss.cpu().detach().numpy()[0].tolist()}")
+            self.log.debug(f"Loss:     {train_loss}")
             losses.append(loss)

             # Every 5 epochs, save a test image
@@ -9,10 +9,10 @@ class BasicAutoEncoder(BaseEncoder):
     # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
     name = "BasicAutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(BasicAutoEncoder, self).__init__(name, input_shape)
+        super(BasicAutoEncoder, self).__init__(name, input_shape, loss_function)

         # Network, optimizer and loss function are the same as defined in the base encoder.
@@ -13,11 +13,11 @@ class ContractiveAutoEncoder(BaseEncoder):
     # Based on https://github.com/avijit9/Contractive_Autoencoder_in_Pytorch/blob/master/CAE_pytorch.py
     name = "ContractiveAutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularizer_weight: float = 1e-4):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularizer_weight: float = 1e-4):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(ContractiveAutoEncoder, self).__init__(name, input_shape)
+        super(ContractiveAutoEncoder, self).__init__(name, input_shape, loss_function)

         self.regularizer_weight = regularizer_weight

@@ -15,12 +15,12 @@ class DenoisingAutoEncoder(BaseEncoder):
     # Based on https://github.com/pranjaldatta/Denoising-Autoencoder-in-Pytorch/blob/master/DenoisingAutoencoder.ipynb
     name = "DenoisingAutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0,
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None,
                  input_corruption_model: BaseCorruption = NoCorruption):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(DenoisingAutoEncoder, self).__init__(name, input_shape)
+        super(DenoisingAutoEncoder, self).__init__(name, input_shape, loss_function)

         # Network, optimizer and loss function are the same as defined in the base encoder.

@@ -14,11 +14,11 @@ class SparseL1AutoEncoder(BaseEncoder):
     # Based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
     name = "SparseL1AutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0, regularization_parameter: float = 0.001):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None, regularization_parameter: float = 0.001):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(SparseL1AutoEncoder, self).__init__(name, input_shape)
+        super(SparseL1AutoEncoder, self).__init__(name, input_shape, loss_function)

         # Override parameters to custom values for this encoder type

@@ -7,11 +7,53 @@ from typing import Optional

 import numpy
 import torch
+from torch.nn.modules.loss import _Loss

 from config import DATASET_STORAGE_BASE_PATH
 from models.base_dataset import BaseDataset


+class USWeatherLoss(_Loss):
+    __constants__ = ['reduction']
+
+    def __init__(self, dataset=None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
+        self.dataset = dataset
+        super(USWeatherLoss, self).__init__(size_average, reduce, reduction)
+
+        self.ce_loss = torch.nn.CrossEntropyLoss()
+        self.l1_loss = torch.nn.L1Loss()
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        losses = []
+        start = 0
+        length = len(self.dataset._labels['Type'])
+        # Type is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        start += length
+        length = len(self.dataset._labels['Severity'])
+        # Severity is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        start += length
+        # Start time is a number, so use L1 loss
+        losses.append(self.l1_loss(input[start], target[start]))
+        # End time is a number, so use L1 loss
+        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
+        start += 2
+        length = len(self.dataset._labels['TimeZone'])
+        # TimeZone is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        start += length
+        # Location latitude is a number, so use L1 loss
+        losses.append(self.l1_loss(input[start], target[start]))
+        # Location longitude is a number, so use L1 loss
+        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
+        start += 2
+        length = len(self.dataset._labels['State'])
+        # State is 1-hot encoded, so use cross entropy loss
+        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        return sum(losses)
+
+
 class USWeatherEventsDataset(BaseDataset):
     # Source: https://smoosavi.org/datasets/lstw
     # https://www.kaggle.com/sobhanmoosavi/us-weather-events
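USWeatherLoss assumes a fixed feature layout: a one-hot Type block, a one-hot Severity block, two timestamp scalars, a one-hot TimeZone block, two coordinate scalars, and a one-hot State block, concatenated in that order, with block widths taken from the dataset's _labels. Note that the slices index the tensor's first dimension; whether that axis holds features or batch samples depends on how the training loop shapes its tensors. The cross-entropy-over-a-slice pattern, checked in isolation with a made-up block of four categories:

    import torch

    ce = torch.nn.CrossEntropyLoss()
    logits = torch.tensor([[0.2, 1.5, -0.3, 0.1]])  # reconstructed block: 1 sample, 4 classes
    one_hot = torch.tensor([[0.0, 1.0, 0.0, 0.0]])  # target block, one-hot encoded
    # argmax turns the one-hot row back into the class index CrossEntropyLoss expects.
    loss = ce(logits, torch.argmax(one_hot.long(), dim=1))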
@@ -107,7 +149,9 @@ class USWeatherEventsDataset(BaseDataset):
                 pickle.dump(dict(self._labels), f)
             self.log.info("Cached version created.")

-        train_data, test_data = self._data[:2500000], self._data[2500000:]
+        # train_data, test_data = self._data[:2500000], self._data[2500000:]
+        # Speed up training a bit
+        train_data, test_data = self._data[:50000], self._data[100000:150000]

         self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
                                                 source_path=self._source_path)
@@ -140,30 +184,29 @@ class USWeatherEventsDataset(BaseDataset):

         return data

-    def save_batch_to_sample(self, batch, filename):
-        res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
-
-        for row in batch:
+    def output_to_result_row(self, output):
         # Get 1-hot encoded values as list per value, and other values as value
-            row = row.tolist()
+        if not isinstance(output, list):
+            output = output.tolist()

         start = 0
         length = len(self._labels['Type'])
-            event_types = row[start:start+length]
+        event_types = output[start:start+length]
         start += length
         length = len(self._labels['Severity'])
-            severities = row[start:start+length]
+        severities = output[start:start+length]
         start += length
-            start_time = row[start]
-            end_time = row[start+1]
+        start_time = output[start]
+        end_time = output[start+1]
         start += 2
         length = len(self._labels['TimeZone'])
-            timezones = row[start:start+length]
+        timezones = output[start:start+length]
         start += length
-            location_lat = row[start]
-            location_lng = row[start+1]
+        location_lat = output[start]
+        location_lng = output[start+1]
         start += 2
         length = len(self._labels['State'])
-            states = row[start:start+length]
+        states = output[start:start+length]

         # Convert 1-hot encodings to normal labels, assume highest value as the true value.
         event_type = self._labels['Type'][event_types.index(max(event_types))]
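output_to_result_row decodes each one-hot block by taking its largest activation, as the conversion line above shows for Type. The same pattern in isolation, with invented labels and values:

    severities = [0.1, 0.7, 0.2]              # reconstructed 'Severity' block
    labels = ['Light', 'Moderate', 'Severe']  # assumed label order from _labels
    severity = labels[severities.index(max(severities))]  # -> 'Moderate'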
@@ -175,7 +218,14 @@ class USWeatherEventsDataset(BaseDataset):
         start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
         end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")

-            res.append(f"{event_type},{severity},{start_time},{end_time},{timezone},{location_lat},{location_lng},{state}\n")
+        return [event_type, severity, start_time, end_time, timezone, location_lat, location_lng, state]
+
+    def save_batch_to_sample(self, batch, filename):
+        res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
+
+        for row in batch:
+            row = row.tolist()
+            res.append(",".join(map(lambda x: f'{x}', self.output_to_result_row(row)))+"\n")

         with open(f"{filename}.csv", "w") as f:
             f.writelines(res)
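The map(lambda x: f'{x}', ...) call only performs string conversion; an equivalent, slightly more direct spelling of the same row-to-CSV step would be (values invented):

    row_values = ['Rain', 'Light', '2020-01-01 00:00:00', '2020-01-01 01:00:00',
                  'US/Eastern', 40.1, -75.2, 'PA']  # e.g. from output_to_result_row(row)
    line = ",".join(str(x) for x in row_values) + "\n"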
@@ -186,7 +236,10 @@ class USWeatherEventsDataset(BaseDataset):

         total_score = 0
         for i in range(len(originals)):
-            original, recon = originals[i], reconstruction[i]
+            original, recon = self.output_to_result_row(originals[i]), self.output_to_result_row(reconstruction[i])
             total_score += sum(int(original[j] == recon[j]) for j in range(len(original))) / len(original)

         return total_score / len(originals)
+
+    def get_loss_function(self):
+        return USWeatherLoss(dataset=self)
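With rows decoded through output_to_result_row, the score becomes the mean fraction of CSV columns that match exactly between an original and its reconstruction. A worked example with hypothetical rows, where 7 of the 8 fields agree:

    original = ['Rain', 'Light', '2020-01-01 00:00:00', '2020-01-01 01:00:00', 'US/Eastern', '40.1', '-75.2', 'PA']
    recon    = ['Rain', 'Light', '2020-01-01 00:05:00', '2020-01-01 01:00:00', 'US/Eastern', '40.1', '-75.2', 'PA']
    score = sum(int(a == b) for a, b in zip(original, recon)) / len(original)  # 7/8 = 0.875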
@@ -12,11 +12,11 @@ class VariationalAutoEncoder(BaseEncoder):
     # and https://github.com/pytorch/examples/blob/master/vae/main.py
     name = "VariationalAutoEncoder"

-    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0, loss_function=None):
         self.log = logging.getLogger(self.__class__.__name__)

         # Call superclass to initialize parameters.
-        super(VariationalAutoEncoder, self).__init__(name, input_shape)
+        super(VariationalAutoEncoder, self).__init__(name, input_shape, loss_function)

         # VAE needs intermediate output of the encoder stage, so split up the network into encoder/decoder
         # with no ReLU layer at the end of the encoder so we have access to the mu and variance.