Allow the input shape to be defined by the dataset, save loss values as a CSV after training, and implement a basic version of the US Weather Events dataset; the latter is still very slow and gives poor results, probably due to an input encoding issue.
parent f76374111c
commit f6a19c4921
@@ -25,7 +25,7 @@ args=('output.log', 'w')
 
 [handler_consoleHandler]
 class=StreamHandler
-level=INFO
+level=DEBUG
 formatter=simpleFormatter
 args=(sys.stdout,)
 
main.py (2 changed lines)
@@ -34,6 +34,8 @@ def run_tests():
 
         # Create TestRun instance
         dataset = dataset_model(**test['dataset_kwargs'])
+        if test['encoder_kwargs'].get('input_shape', None) is None:
+            test['encoder_kwargs']['input_shape'] = dataset.get_input_shape()
        encoder = encoder_model(**test['encoder_kwargs'])
        encoder.after_init()
        corruption = corruption_model(**test['corruption_kwargs'])
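Note: the two added lines above only query the dataset when no explicit input_shape is configured. As a rough illustration of that fallback pattern in isolation (a sketch, not repository code; DummyDataset and DummyEncoder are made-up names):

    # Minimal sketch of the "input shape comes from the dataset" fallback.
    # DummyDataset / DummyEncoder are hypothetical stand-ins, not repo classes.
    class DummyDataset:
        def get_input_shape(self):
            return 784  # e.g. a flattened 28x28 image

    class DummyEncoder:
        def __init__(self, input_shape):
            self.input_shape = input_shape

    encoder_kwargs = {}                     # no input_shape configured
    dataset = DummyDataset()
    if encoder_kwargs.get('input_shape', None) is None:
        encoder_kwargs['input_shape'] = dataset.get_input_shape()
    encoder = DummyEncoder(**encoder_kwargs)
    print(encoder.input_shape)              # -> 784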
@@ -1,3 +1,5 @@
+import torch
+
 from models.base_dataset import BaseDataset
 
 
@@ -30,8 +32,8 @@ class NoCorruption(BaseCorruption):
     name = "No corruption"
 
     @classmethod
-    def corrupt_image(cls, image):
-        return image
+    def corrupt_image(cls, image: torch.Tensor):
+        return image.numpy()
 
     @classmethod
     def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
@@ -62,6 +62,9 @@ class BaseDataset(Dataset):
             self._source_path = path
         raise NotImplementedError()
 
+    def get_input_shape(self):
+        return None
+
     def _subdivide(self, amount: Union[int, float]):
         if self._data is None:
             raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
@@ -117,7 +117,7 @@ class BaseEncoder(torch.nn.Module):
         outputs = None
         for epoch in range(epochs):
             self.log.debug(f"Start training epoch {epoch + 1}...")
-            loss = 0
+            loss = []
             for i, batch_features in enumerate(train_loader):
                 # # load batch features to the active device
                 # batch_features = batch_features.to(self.device)
@@ -151,15 +151,15 @@ class BaseEncoder(torch.nn.Module):
                 self.optimizer.step()
 
                 # add the mini-batch training loss to epoch loss
-                loss += train_loss.item()
+                loss.append(train_loss.item())
 
                 # Print progress every 50 batches
-                if i % 50 == 0:
+                if i % 100 == 0:
                     self.log.debug(f" progress: [{i * len(batch_features)}/{len(train_loader.dataset)} "
                                    f"({(100 * i / len(train_loader)):.0f}%)]")
 
             # compute the epoch training loss
-            loss = loss / len(train_loader)
+            loss = sum(loss) / len(loss)
 
             # display the epoch training loss
             self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
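Note: collecting the per-batch losses in a list instead of a running total leaves the reported epoch loss unchanged (it is the same mean over batches), while keeping the individual values available for the CSV export added in utils.py. A standalone sketch of the equivalence (illustrative values only, not repository code):

    # Illustrative per-batch losses; in the real loop these come from train_loss.item().
    batch_losses = [0.91, 0.72, 0.65, 0.60]

    running_total = sum(batch_losses)            # old style: loss += train_loss.item()
    epoch_loss_old = running_total / len(batch_losses)
    epoch_loss_new = sum(batch_losses) / len(batch_losses)  # new style on the list
    assert epoch_loss_old == epoch_loss_new
    print(f"epoch loss = {epoch_loss_new:.6f}")  # the list itself can now be saved to CSV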
@@ -50,6 +50,9 @@ class Cifar10Dataset(BaseDataset):
 
         self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")
 
+    def get_input_shape(self):
+        return 3072  # 32x32x3 (32x32px, 3 colors)
+
     def __getitem__(self, item):
         # Get image data
         img = self._data[item]
@@ -87,7 +90,7 @@ class Cifar10Dataset(BaseDataset):
         return img
 
     def save_batch_to_sample(self, batch, filename):
-        img = batch.view(batch.size(0), 3, 32, 32)
+        img = batch.view(batch.size(0), 3, 32, 32)[:48]
         save_image(img, f"{filename}.png")
 
     def calculate_score(self, originals, reconstruction, device):
@@ -41,6 +41,9 @@ class MNISTDataset(BaseDataset):
 
         self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")
 
+    def get_input_shape(self):
+        return 784  # 28x28x1 (28x28px, 1 color)
+
     def __getitem__(self, item):
         # Get image data
         img = self._data[item]
@@ -55,7 +58,7 @@ class MNISTDataset(BaseDataset):
         return img
 
     def save_batch_to_sample(self, batch, filename):
-        img = batch.view(batch.size(0), 1, 28, 28)
+        img = batch.view(batch.size(0), 1, 28, 28)[:48]
         save_image(img, f"{filename}.png")
 
     def calculate_score(self, originals, reconstruction, device):
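Note: the two hard-coded shapes returned by get_input_shape above are just the flattened image sizes. A quick sanity check of the arithmetic (plain PyTorch, not repository code):

    import torch

    # A CIFAR-10 image: 3 color channels of 32x32 pixels -> 3072 values when flattened.
    cifar_image = torch.zeros(3, 32, 32)
    assert cifar_image.numel() == 3 * 32 * 32 == 3072

    # An MNIST image: 1 channel of 28x28 pixels -> 784 values when flattened.
    mnist_image = torch.zeros(1, 28, 28)
    assert mnist_image.numel() == 1 * 28 * 28 == 784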
models/random_corruption.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+import random
+
+from torch import Tensor
+
+from models.base_corruption import BaseCorruption
+from models.base_dataset import BaseDataset
+import numpy
+
+
+def add_noise(image):
+    if isinstance(image, Tensor):
+        image = image.numpy()
+    image = image.astype(numpy.float32)
+
+    # 90% chance to corrupt something
+    if random.random() < 0.9:
+        corrupt_index1, corrupt_index2 = random.sample(range(len(image)), 2)
+        image[corrupt_index1] = 0
+        # 10% chance to corrupt a second column
+        if random.random() < 0.1:
+            image[corrupt_index2] = 0
+
+    return image
+
+
+class RandomCorruption(BaseCorruption):
+    """
+    Corruption model that clears random fields of data.
+    """
+    name = "Gaussian"
+
+    @classmethod
+    def corrupt_image(cls, image: Tensor):
+        return add_noise(image.numpy())
+
+    @classmethod
+    def corrupt_dataset(cls, dataset: BaseDataset) -> BaseDataset:
+        data = [cls.corrupt_image(x) for x in dataset]
+        # data = list(map(add_noise, dataset._data))
+        train_set = cls.corrupt_dataset(dataset.get_train()) if dataset.has_train() else None
+        test_set = cls.corrupt_dataset(dataset.get_test()) if dataset.has_test() else None
+        return dataset.__class__.get_new(
+            name=f"{dataset.name} Corrupted",
+            data=data,
+            labels=dataset._labels,
+            source_path=dataset._source_path,
+            train_set=train_set,
+            test_set=test_set)
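Note: add_noise above zeroes one randomly chosen field with ~90% probability, and a second one with a further ~10% probability. A small usage sketch against a plain numpy vector (illustrative only; the seed is just to make the example repeatable):

    import random
    import numpy

    from models.random_corruption import add_noise

    random.seed(0)  # repeatable illustration only

    sample = numpy.ones(10, dtype=numpy.float32)  # stand-in for one feature vector
    corrupted = add_noise(sample.copy())
    print(corrupted)  # usually exactly one (occasionally two) entries are now 0.0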
@@ -4,7 +4,7 @@ import multiprocessing
 from models.base_corruption import BaseCorruption
 from models.base_dataset import BaseDataset
 from models.base_encoder import BaseEncoder
-from utils import save_train_loss_graph
+from utils import save_train_loss_graph, save_train_loss_values
 
 
 class TestRun:
@@ -41,7 +41,7 @@ class TestRun:
         if retrain:
             # Train encoder
             self.log.info("Training auto-encoder...")
-            train_loss = self.encoder.train_encoder(self.dataset, epochs=50, num_workers=multiprocessing.cpu_count() - 1)
+            train_losses = self.encoder.train_encoder(self.dataset, epochs=50, num_workers=multiprocessing.cpu_count() - 1)
 
             if save_model:
                 self.log.info("Saving auto-encoder model...")
@@ -49,7 +49,8 @@ class TestRun:
 
             # Save train loss graph
             self.log.info("Saving loss graph...")
-            save_train_loss_graph(train_loss, f"{self.encoder.name}_{self.dataset.name}")
+            save_train_loss_graph(train_losses, f"{self.encoder.name}_{self.dataset.name}")
+            save_train_loss_values(train_losses, f"{self.encoder.name}_{self.dataset.name}")
         else:
             self.log.info("Loading saved auto-encoder...")
             load_success = self.encoder.load_model(f"{self.encoder.name}_{self.dataset.name}")
models/usweather_dataset.py (new file, 192 lines)
@@ -0,0 +1,192 @@
+import csv
+import os
+from collections import defaultdict
+from datetime import datetime
+
+from typing import Optional
+
+import numpy
+import torch
+
+from config import DATASET_STORAGE_BASE_PATH
+from models.base_dataset import BaseDataset
+
+
+class USWeatherEventsDataset(BaseDataset):
+    # Source: https://smoosavi.org/datasets/lstw
+    # https://www.kaggle.com/sobhanmoosavi/us-weather-events
+    name = "US Weather Events"
+
+    def transform(self, data):
+        return torch.from_numpy(numpy.array(data, numpy.float32, copy=False))
+
+    def unpickle(self, filename):
+        import pickle
+        with open(filename, 'rb') as fo:
+            dict = pickle.load(fo, encoding='bytes')
+        return dict
+
+    def load(self, name: Optional[str] = None, path: Optional[str] = None):
+        if name is not None:
+            self.name = name
+        if path is not None:
+            self._source_path = path
+
+        self._data = []
+        self._labels = defaultdict(list)
+
+        # Load from cache pickle file if it exists, else create cache file and load from csv
+        if os.path.isfile(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_data.pickle"))\
+                and os.path.isfile(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle")):
+            self.log.info("Loading cached version of dataset...")
+            self._data = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_data.pickle"))
+            self._labels = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle"))
+        else:
+            self.log.info("Creating cached version of dataset...")
+            size = 5023709
+            with open(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "WeatherEvents_Aug16_June20_Publish.csv")) as f:
+                data = csv.DictReader(f)
+                # Build label map before processing for 1-hot encoding
+                self.log.info("Preparing labels...")
+                for i, row in enumerate(data):
+                    if i % 500000 == 0:
+                        self.log.debug(f"{i} / ~{size} ({((i / size) * 100):.4f}%)")
+
+                    for label_type in ['Type', 'Severity', 'TimeZone', 'State']:
+                        if row[label_type] not in self._labels[label_type]:
+                            self._labels[label_type].append(row[label_type])
+
+            with open(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "WeatherEvents_Aug16_June20_Publish.csv")) as f:
+                data = csv.DictReader(f)
+                self.log.info("Processing data...")
+                for i, row in enumerate(data):
+                    self._data.append(numpy.array([] +
+                        # Event ID doesn't matter
+                        # 1-hot encoded event type columns
+                        [int(row['Type'] == self._labels['Type'][i]) for i in range(len(self._labels['Type']))] +
+
+                        # 1-hot encoded event severity columns
+                        [int(row['Severity'] == self._labels['Severity'][i]) for i in range(len(self._labels['Severity']))] +
+
+                        [
+                            # Start time as unix timestamp
+                            datetime.strptime(row['StartTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp(),
+                            # End time as unix timestamp
+                            datetime.strptime(row['EndTime(UTC)'], "%Y-%m-%d %H:%M:%S").timestamp()
+                        ] +
+
+                        # 1-hot encoded event timezone columns
+                        [int(row['TimeZone'] == self._labels['TimeZone'][i]) for i in range(len(self._labels['TimeZone']))] +
+
+                        [
+                            # Location Latitude as float
+                            float(row['LocationLat']),
+                            # Location Longitude as float
+                            float(row['LocationLng']),
+                        ] +
+
+                        # 1-hot encoded event state columns
+                        [int(row['State'] == self._labels['State'][i]) for i in range(len(self._labels['State']))]
+
+                        # Airport code, city, county and zip code are not considered,
+                        # as they have too many unique values for 1-hot encoding.
+                    ))
+
+                    if i % 500000 == 0:
+                        self.log.debug(f"{i} / ~{size} ({((i / size) * 100):.4f}%)")
+
+            self.log.info("Shuffling data...")
+            rng = numpy.random.default_rng()
+            rng.shuffle(self._data)
+
+            self.log.info("Saving cached version...")
+            import pickle
+            with open(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_data.pickle"), 'wb') as f:
+                pickle.dump(self._data, f)
+            with open(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle"), 'wb') as f:
+                pickle.dump(dict(self._labels), f)
+            self.log.info("Cached version created.")
+
+        train_data, test_data = self._data[:2500000], self._data[2500000:]
+
+        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, labels=self._labels,
+                                                source_path=self._source_path)
+
+        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, labels=self._labels,
+                                               source_path=self._source_path)
+
+        self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")
+
+    def get_input_shape(self):
+        if os.path.isfile(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle")):
+            labels = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH, self._source_path, "weather_py_labels.pickle"))
+            size = 0
+            size += len(labels['Type'])
+            size += len(labels['Severity'])
+            size += 2
+            size += len(labels['TimeZone'])
+            size += 2
+            size += len(labels['State'])
+            return size
+        else:
+            return 69
+
+    def __getitem__(self, item):
+        data = self._data[item]
+
+        # Run transforms
+        if self.transform is not None:
+            data = self.transform(data)
+
+        return data
+
+    def save_batch_to_sample(self, batch, filename):
+        res = ["Type,Severity,StartTime(UTC),EndTime(UTC),TimeZone,LocationLat,LocationLng,State\n"]
+
+        for row in batch:
+            # Get 1-hot encoded values as list per value, and other values as value
+            row = row.tolist()
+            start = 0
+            length = len(self._labels['Type'])
+            event_types = row[start:start+length]
+            start += length
+            length = len(self._labels['Severity'])
+            severities = row[start:start+length]
+            start += length
+            start_time = row[start]
+            end_time = row[start+1]
+            start += 2
+            length = len(self._labels['TimeZone'])
+            timezones = row[start:start+length]
+            start += length
+            location_lat = row[start]
+            location_lng = row[start+1]
+            start += 2
+            length = len(self._labels['State'])
+            states = row[start:start+length]
+
+            # Convert 1-hot encodings to normal labels, assume highest value as the true value.
+            event_type = self._labels['Type'][event_types.index(max(event_types))]
+            severity = self._labels['Severity'][severities.index(max(severities))]
+            timezone = self._labels['TimeZone'][timezones.index(max(timezones))]
+            state = self._labels['State'][states.index(max(states))]
+
+            # Convert timestamp float into string time
+            start_time = datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
+            end_time = datetime.fromtimestamp(end_time).strftime("%Y-%m-%d %H:%M:%S")
+
+            res.append(f"{event_type},{severity},{start_time},{end_time},{timezone},{location_lat},{location_lng},{state}\n")
+
+        with open(f"{filename}.csv", "w") as f:
+            f.writelines(res)
+
+    def calculate_score(self, originals, reconstruction, device):
+        originals = originals.to(device)
+        reconstruction = reconstruction.to(device)
+
+        total_score = 0
+        for i in range(len(originals)):
+            original, recon = originals[i], reconstruction[i]
+            total_score += sum(int(original[j] == recon[j]) for j in range(len(original))) / len(original)
+
+        return total_score / len(originals)
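Note on the "input encoding issue" mentioned in the commit message: one likely contributor, judging from the feature vector built above, is scale. The one-hot columns are 0/1, while the start/end timestamps are on the order of 1.5e9 and latitude/longitude span roughly -180..180, so the raw numeric columns dominate any reconstruction loss. A hedged sketch of per-column min-max scaling that could be applied before training (not part of this commit; scale_columns and the column indices are illustrative):

    import numpy

    def scale_columns(data: numpy.ndarray, columns: list) -> numpy.ndarray:
        """Min-max scale the given columns of a 2D feature matrix into [0, 1].

        Hypothetical helper, not part of the repository; shown only to illustrate
        one way the timestamp and lat/lng columns could be brought onto the same
        scale as the 1-hot columns.
        """
        scaled = data.astype(numpy.float32)
        for col in columns:
            lo, hi = scaled[:, col].min(), scaled[:, col].max()
            if hi > lo:
                scaled[:, col] = (scaled[:, col] - lo) / (hi - lo)
        return scaled

    # Example: rows of [one-hot..., start_ts, end_ts, ..., lat, lng, one-hot...]
    # with the numeric columns at (assumed) indices 7, 8, 13 and 14:
    # features = scale_columns(features, columns=[7, 8, 13, 14])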
utils.py (5 changed lines)
@@ -116,3 +116,8 @@ def save_train_loss_graph(train_loss, filename):
     plt.ylabel('Loss')
     plt.yscale('log')
     plt.savefig(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.png'))
+
+
+def save_train_loss_values(train_loss, filename):
+    with open(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.csv'), 'w') as f:
+        f.write(",".join(map(str, train_loss)))
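Note: the new helper writes the loss values as a single comma-separated line. If those values need to be read back for later analysis, something along these lines would work (a sketch; the file path shown is illustrative, the real location depends on TRAIN_TEMP_DATA_BASE_PATH and the encoder/dataset names):

    # Read a loss CSV written by save_train_loss_values back into a list of floats.
    # The file contains one comma-separated line, e.g. "0.91,0.72,0.65,0.60".
    with open('train_temp_data/MyEncoder_MNIST_loss.csv') as f:  # hypothetical path
        losses = [float(value) for value in f.read().split(',')]
    print(losses)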