Small changes to make the US weather dataset work properly; example test runs in config file.
parent e4c51e2d3d
commit d0785e12e2

4 changed files with 154 additions and 56 deletions
Config file:

@@ -3,14 +3,147 @@ DATASET_STORAGE_BASE_PATH = "/path/to/this/project/datasets"
 TRAIN_TEMP_DATA_BASE_PATH = "/path/to/this/project/train_temp"
 TEST_TEMP_DATA_BASE_PATH = "/path/to/this/project/test_temp"
 
 
 TEST_RUNS = [
-    {
-        'name': "Basic test run",
-        'encoder_model': "models.base_encoder.BaseEncoder",
-        'encoder_kwargs': {},
-        'dataset_model': "models.base_dataset.BaseDataset",
-        'dataset_kwargs': {},
-        'corruption_model': "models.base_corruption.NoCorruption",
-        'corruption_kwargs': {},
-    },
+    # CIFAR-10 dataset
+    # {
+    #     'name': "CIFAR-10 on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.gaussian_corruption.GaussianCorruption"},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "CIFAR-10 on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.cifar10_dataset.Cifar10Dataset",
+    #     'dataset_kwargs': {"path": "cifar-10-batches-py"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+
+    # MNIST dataset
+    # {
+    #     'name': "MNIST on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.gaussian_corruption.GaussianCorruption"},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "MNIST on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.mnist_dataset.MNISTDataset",
+    #     'dataset_kwargs': {"path": "mnist"},
+    #     'corruption_model': "models.gaussian_corruption.GaussianCorruption",
+    #     'corruption_kwargs': {},
+    # },
+
+    # US Weather Events dataset
+    # {
+    #     'name': "US Weather Events on basic auto-encoder",
+    #     'encoder_model': "models.basic_encoder.BasicAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on sparse L1 auto-encoder",
+    #     'encoder_model': "models.sparse_encoder.SparseL1AutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on denoising auto-encoder",
+    #     'encoder_model': "models.denoising_encoder.DenoisingAutoEncoder",
+    #     'encoder_kwargs': {'input_corruption_model': "models.random_corruption.RandomCorruption"},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on contractive auto-encoder",
+    #     'encoder_model': "models.contractive_encoder.ContractiveAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
+    # {
+    #     'name': "US Weather Events on variational auto-encoder",
+    #     'encoder_model': "models.variational_encoder.VariationalAutoEncoder",
+    #     'encoder_kwargs': {},
+    #     'dataset_model': "models.usweather_dataset.USWeatherEventsDataset",
+    #     'dataset_kwargs': {"path": "weather-events"},
+    #     'corruption_model': "models.random_corruption.RandomCorruption",
+    #     'corruption_kwargs': {},
+    # },
 ]
 
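Note on the entries above: each test run names its encoder, dataset, and corruption model as a dotted import string rather than a class object. A minimal sketch of how such strings can be resolved at runtime; the project's actual loader is not part of this diff, so the helper name below is illustrative:

import importlib

def resolve_model(dotted_path):
    # "models.basic_encoder.BasicAutoEncoder" -> the BasicAutoEncoder class
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)

# Hypothetical usage for one TEST_RUNS entry `run`:
# encoder = resolve_model(run['encoder_model'])(**run['encoder_kwargs'])
# dataset = resolve_model(run['dataset_model'])(**run['dataset_kwargs'])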
main.py:

@@ -44,7 +44,7 @@ def run_tests():
         test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
 
         # Run TestRun
-        test_run.run(retrain=True)
+        test_run.run(retrain=False)
 
         # Cleanup to avoid out-of-memory situations when running lots of tests
         del test_run
Contractive auto-encoder module:

@@ -73,6 +73,7 @@ class ContractiveAutoEncoder(BaseEncoder):
         weights = self.state_dict()['encoder.2.weight']
 
         # Hadamard product
+        if len(hidden_output.shape) > 2:
             hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
         dh = hidden_output * (1 - hidden_output)
 
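The added guard collapses a stray middle dimension (e.g. a (batch, 1, hidden) activation) to (batch, hidden) before the element-wise product. For context, a minimal sketch of the contractive penalty this method appears to build, assuming sigmoid hidden units; the function name and usage are illustrative, not the project's API:

import torch

def contractive_penalty(hidden_output, weights):
    # Collapse (batch, 1, hidden) to (batch, hidden), mirroring the guard added here.
    if len(hidden_output.shape) > 2:
        hidden_output = hidden_output.reshape(hidden_output.shape[0], hidden_output.shape[2])
    # For sigmoid units h, the derivative w.r.t. the pre-activation is h * (1 - h):
    # the Hadamard product computed in the diff above.
    dh = hidden_output * (1 - hidden_output)
    # Squared Frobenius norm of the encoder Jacobian: sum_i dh_i^2 * ||W_i||^2,
    # summed over the batch. weights has shape (hidden, input).
    return ((dh ** 2) * (weights ** 2).sum(dim=1)).sum()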
US weather dataset module:

@@ -1,7 +1,6 @@
 import csv
 import os
 from collections import defaultdict
-from datetime import datetime
 
 from typing import Optional
 
@@ -26,31 +25,22 @@ class USWeatherLoss(_Loss):
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         losses = []
         start = 0
 
         length = len(self.dataset._labels['Type'])
         # Type is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
         length = len(self.dataset._labels['Severity'])
         # Severity is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
-        # Start time is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start], target[start]))
-        # End time is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
-        start += 2
         length = len(self.dataset._labels['TimeZone'])
         # TimeZone is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         start += length
-        # Location latitude is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start], target[start]))
-        # Location longitude is a number, so use L1 loss
-        losses.append(self.l1_loss(input[start + 1], target[start + 1]))
-        start += 2
         length = len(self.dataset._labels['State'])
         # State is 1-hot encoded, so use cross entropy loss
-        losses.append(self.ce_loss(input[start:start+length], torch.argmax(target[start:start+length].long(), dim=1)))
+        losses.append(self.ce_loss(input[:, start:start+length], torch.argmax(target[:, start:start+length].long(), dim=1)))
         return sum(losses)
 
 
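The slicing change above is the substantive fix: on a batched (batch, features) tensor, input[start:start+length] selects rows, i.e. whole samples, while input[:, start:start+length] selects the intended one-hot feature block for every sample. A small self-contained illustration (sizes are made up):

import torch

x = torch.randn(32, 65)            # (batch, features)
start, length = 0, 7               # e.g. the 1-hot 'Type' block

wrong = x[start:start+length]      # shape (7, 65): seven samples, all features
right = x[:, start:start+length]   # shape (32, 7): the 'Type' block for each sample
print(wrong.shape, right.shape)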
@@ -110,23 +100,9 @@ class USWeatherEventsDataset(BaseDataset):
                         # 1-hot encoded event severity columns
                         [int(row['Severity'] == self._labels['Severity'][i]) for i in range(len(self._labels['Severity']))] +
 
-                        [
-                            # Start time as unix timestamp
-                            datetime.strptime(row['StartTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp(),
-                            # End time as unix timestamp
-                            datetime.strptime(row['EndTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp()
-                        ] +
-
                         # 1-hot encoded event timezone columns
                         [int(row['TimeZone'] == self._labels['TimeZone'][i]) for i in range(len(self._labels['TimeZone']))] +
 
-                        [
-                            # Location Latitude as float
-                            float(row['LocationLat']),
-                            # Location Longitude as float
-                            float(row['LocationLng']),
-                        ] +
-
                         # 1-hot encoded event state columns
                         [int(row['State'] == self._labels['State'][i]) for i in range(len(self._labels['State']))]
 
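With the timestamp and coordinate columns dropped, each encoded row is now purely a concatenation of indicator blocks, one per categorical field. A minimal sketch of the pattern the list comprehensions above implement (the labels and value here are invented for illustration):

def one_hot(value, labels):
    # One indicator column per known label, in label order.
    return [int(value == label) for label in labels]

severity_labels = ['Light', 'Moderate', 'Severe']
print(one_hot('Moderate', severity_labels))  # [0, 1, 0]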
@@ -151,7 +127,7 @@ class USWeatherEventsDataset(BaseDataset):
 
         # train_data, test_data = self._data[:2500000], self._data[2500000:]
         # Speed up training a bit
-        train_data, test_data = self._data[:50000], self._data[100000:150000]
+        train_data, test_data = self._data[:250000], self._data[250000:500000]
 
         self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
                                                 source_path=self._source_path)
@@ -167,13 +143,11 @@ class USWeatherEventsDataset(BaseDataset):
             size = 0
             size += len(labels['Type'])
             size += len(labels['Severity'])
-            size += 2
             size += len(labels['TimeZone'])
-            size += 2
             size += len(labels['State'])
             return size
         else:
-            return 69
+            return 65
 
     def __getitem__(self, item):
         data = self._data[item]
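The fallback width shrinks by exactly the four removed numeric columns (start time, end time, latitude, longitude): 69 - 4 = 65, keeping it consistent with the label-derived sum |Type| + |Severity| + |TimeZone| + |State| computed above.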
@@ -196,15 +170,9 @@ class USWeatherEventsDataset(BaseDataset):
         length = len(self._labels['Severity'])
         severities = output[start:start+length]
         start += length
-        start_time = output[start]
-        end_time = output[start+1]
-        start += 2
         length = len(self._labels['TimeZone'])
         timezones = output[start:start+length]
         start += length
-        location_lat = output[start]
-        location_lng = output[start+1]
-        start += 2
         length = len(self._labels['State'])
         states = output[start:start+length]
 
@@ -214,14 +182,10 @@ class USWeatherEventsDataset(BaseDataset):
         timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
         state = self._labels['State'][states.index(max(states))]
 
-        # Convert timestamp float into string time
-        start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
-        end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
-
-        return [event_type, severity, start_time, end_time, timezone, location_lat, location_lng, state]
+        return [event_type, severity, timezone, state]
 
     def save_batch_to_sample(self, batch, filename):
-        res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
+        res = ["Type,Severity,TimeZone,State\n"]
 
         for row in batch:
             row = row.tolist()
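Decoding mirrors the loss layout: each one-hot block is sliced off in turn and mapped back to its label by the index of its largest activation, as the two hunks above do with timezones.index(max(timezones)) and friends. A compact sketch of that round trip on plain Python lists (names and values are illustrative):

def decode_block(output, start, labels):
    # Slice this field's block and pick the label with the highest score.
    block = output[start:start + len(labels)]
    return labels[block.index(max(block))], start + len(labels)

severity_labels = ['Light', 'Moderate', 'Severe']
output = [0.1, 2.3, 0.4]   # one decoded row, 'Severity' block only
severity, start = decode_block(output, 0, severity_labels)
print(severity)            # 'Moderate'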