RP_AutoEncoderComparison/utils.py

124 lines
4.4 KiB
Python

import importlib
import os
from string import Template
from tabulate import tabulate
import matplotlib.pyplot as plt
from config import TRAIN_TEMP_DATA_BASE_PATH
def load_dotted_path(path):
    """Import and return the object addressed by a dotted path.

    Everything before the last dot is imported as a module; the part after
    the last dot is then looked up as an attribute of that module.

    :param path: dotted path such as ``package.module.ClassName``
    :type path: str
    :return: the attribute found on the imported module
    """
    module_path, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)
def training_header():
    """Print the column titles and the separator row of the training table."""
    columns = '| Epoch | Avg. loss | Train Acc. | Test Acc. | Elapsed | ETA |'
    separator = '|-------+-----------+------------+------------+----------+----------|'
    print('\n'.join((columns, separator)))
def training_epoch_stats(epoch, running_loss, total, train_acc, test_acc, elapsed, eta):
    """Print one row of the training table for an epoch that has just ended.

    The average loss column is ``running_loss / total``. When ``total`` is
    zero the average is undefined, so ``nan`` is printed instead of letting a
    ``ZeroDivisionError`` abort the run.

    :param epoch: current training epoch
    :type epoch: int
    :param running_loss: current total running loss of the model
    :type running_loss: float
    :param total: number of the images that model has trained on
    :type total: int
    :param train_acc: models accuracy on the provided train dataset
    :type train_acc: str
    :param test_acc: models accuracy on the optionally provided eval dataset
    :type test_acc: str
    :param elapsed: total time it took to run the epoch
    :type elapsed: str
    :param eta: an estimated time when training will end
    :type eta: str
    """
    # Guard the division: nothing processed means there is no average to report.
    avg_loss = running_loss / total if total != 0 else float('nan')
    print(f'| {epoch:5d} | {avg_loss:.7f} | {train_acc:10s} | {test_acc:10s} | {elapsed:8s} | {eta} |')
def training_finished(start, now):
    """Print the total time that elapsed between ``start`` and ``now``.

    :param start: time when the training has started
    :type start: datetime.datetime
    :param now: current time
    :type now: datetime.datetime
    """
    elapsed = now - start
    print(f'Training finished, total time elapsed: {elapsed}\n')
def accuracy_summary_basic(total, correct, acc):
    """Print a one-line evaluation summary: totals and overall accuracy.

    Per the original documentation this is the report used when verbosity
    is set to 1.

    :param total: number of total objects model was evaluated on
    :type total: int
    :param correct: number of correctly classified objects
    :type correct: int
    :param acc: overall accuracy of the model
    :type acc: float
    """
    summary = f'Total: {total} -- Correct: {correct} -- Accuracy: {acc}%\n'
    print(summary)
def accuracy_summary_extended(classes, class_total, class_correct):
    """Print a per-class evaluation table (class, total, correct, accuracy).

    Per the original documentation this is the report used when verbosity
    is equal or greater than 2. The three arguments are indexed in parallel:
    entry ``i`` of ``class_total`` and ``class_correct`` belongs to
    ``classes[i]``.

    :param classes: list with class names on which model was evaluated
    :type classes: list
    :param class_total: per-class number of evaluated objects
    :type class_total: list
    :param class_correct: per-class number of correctly classified objects
    :type class_correct: list
    """
    rows = []
    for idx, name in enumerate(classes):
        seen = class_total[idx]
        # A class that never appeared has no meaningful accuracy; show a dash
        # rather than dividing by zero.
        if seen != 0:
            class_acc = '{:.1f}%'.format(100 * class_correct[idx] / seen)
        else:
            class_acc = '-'
        rows.append([name, seen, class_correct[idx], class_acc])
    print(tabulate(rows, headers=['Class', 'Total', 'Correct', 'Acc'], tablefmt='orgtbl'), '\n')
class _DeltaTemplate(Template):
    """``string.Template`` variant that uses ``%`` as the placeholder delimiter."""

    delimiter = "%"


def strfdelta(tdelta, fmt='%H:%M:%S'):
    """Similar to strftime, but this one is for a datetime.timedelta object.

    Supported placeholders are ``%D`` (days), ``%H`` (hours, two digits),
    ``%M`` (minutes, two digits) and ``%S`` (seconds, two digits); the braced
    form ``%{D}`` is accepted as well. Microseconds are ignored.

    When the format does not contain a days placeholder, whole days are
    folded into the hours field so that no time is silently dropped (25 hours
    renders as ``25:00:00`` with the default format instead of ``01:00:00``).

    :param tdelta: datetime object containing some time difference
    :type tdelta: datetime.timedelta
    :param fmt: string with a format
    :type fmt: str
    :return: string containing a time in a given format
    :raises KeyError: if ``fmt`` uses a placeholder other than D, H, M or S
    :raises ValueError: if ``fmt`` contains a malformed placeholder
    """
    hours, rem = divmod(tdelta.seconds, 3600)
    minutes, seconds = divmod(rem, 60)

    template = _DeltaTemplate(fmt)
    # Collect the placeholder names actually present in the format string.
    placeholders = {
        match.group("named") or match.group("braced")
        for match in template.pattern.finditer(fmt)
    }
    # Without a days field the day count would vanish from the output, so
    # carry it over into hours. Negative deltas are left untouched.
    if "D" not in placeholders and tdelta.days > 0:
        hours += tdelta.days * 24

    values = {
        "D": tdelta.days,
        "H": '{:02d}'.format(hours),
        "M": '{:02d}'.format(minutes),
        "S": '{:02d}'.format(seconds),
    }
    return template.substitute(values)
def save_train_loss_graph(train_loss, filename):
    """Plot the training loss on a log-scaled y axis and save it as a PNG.

    The image is written to ``TRAIN_TEMP_DATA_BASE_PATH`` under the name
    ``<filename>_loss.png``. The figure is closed afterwards, even if
    plotting or saving fails, so repeated calls do not accumulate open
    figures in pyplot's global state.

    :param train_loss: sequence of loss values, plotted against their index
        (the x axis is labelled 'Epochs')
    :param filename: base name of the output file, without suffix
    :type filename: str
    """
    fig = plt.figure()
    try:
        plt.plot(train_loss)
        plt.title('Train Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.yscale('log')
        plt.savefig(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.png'))
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
def save_train_loss_values(train_loss, filename):
    """Write the loss values to a CSV file as one comma-separated line.

    The file is created (or overwritten) in ``TRAIN_TEMP_DATA_BASE_PATH``
    under the name ``<filename>_loss.csv``. Each value is converted with
    ``str``; no trailing newline is written.

    :param train_loss: iterable of loss values
    :param filename: base name of the output file, without suffix
    :type filename: str
    """
    target = os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.csv')
    with open(target, 'w') as out:
        out.write(",".join(str(value) for value in train_loss))