RP_AutoEncoderComparison/main.py

58 lines
2.3 KiB
Python

import config
import logging.config
# Get logging as early as possible!
logging.config.fileConfig("logging.conf")
from utils import load_dotted_path
from models.base_corruption import BaseCorruption
from models.base_dataset import BaseDataset
from models.base_encoder import BaseEncoder
from models.test_run import TestRun
def _load_model(logger, dotted_path, base_class, role):
    """Load the object at *dotted_path* and check it derives from *base_class*.

    Args:
        logger: Logger that receives the debug message naming the model.
        dotted_path: Dotted path handed to ``load_dotted_path``.
        base_class: Class the loaded model must be a subclass of.
        role: Short label used in messages ('dataset', 'encoder' or
            'corruption').

    Returns:
        The loaded model class.

    Raises:
        AssertionError: If the loaded class is not a subclass of *base_class*.
    """
    model = load_dotted_path(dotted_path)
    # Raised explicitly instead of via `assert` so the check is not stripped
    # when Python runs with -O; the exception type and message are unchanged.
    if not issubclass(model, base_class):
        raise AssertionError(
            f"Invalid {role}_model: '{model.__name__}', "
            f"should be subclass of {base_class.__name__}."
        )
    logger.debug("Using %s model '%s'", role, model.__name__)
    return model


def run_tests():
    """Run every test run listed in ``config.TEST_RUNS``.

    For each entry the dataset, encoder and corruption model classes are
    loaded from their dotted paths and validated against their base classes.
    They are then instantiated with the entry's ``dataset_kwargs``,
    ``encoder_kwargs`` and ``corruption_kwargs`` and combined into a
    ``TestRun``, which is run with ``retrain=False``.

    If the encoder kwargs carry no ``input_shape`` or ``loss_function``
    (missing or ``None``), the value is obtained from the dataset instance.
    These defaults are filled into a per-run copy of the kwargs, so the
    entries of ``config.TEST_RUNS`` are never modified and values derived
    from one dataset cannot leak into another run that shares the same
    kwargs dict.

    Raises:
        AssertionError: If a configured model is not a subclass of the
            expected base class.
        KeyError: If a test entry lacks one of the keys read here.
    """
    logger = logging.getLogger("main.run_tests")
    for test in config.TEST_RUNS:
        logger.info("Running test run '%s'...", test['name'])

        # Resolve and validate the three model classes before building anything.
        dataset_model = _load_model(logger, test['dataset_model'], BaseDataset, 'dataset')
        encoder_model = _load_model(logger, test['encoder_model'], BaseEncoder, 'encoder')
        corruption_model = _load_model(logger, test['corruption_model'], BaseCorruption, 'corruption')

        # Create TestRun instance
        dataset = dataset_model(**test['dataset_kwargs'])

        # Work on a shallow copy: writing the dataset-derived defaults back
        # into the shared config would make them stick for later runs.
        encoder_kwargs = dict(test['encoder_kwargs'])
        if encoder_kwargs.get('input_shape') is None:
            encoder_kwargs['input_shape'] = dataset.get_input_shape()
        if encoder_kwargs.get('loss_function') is None:
            encoder_kwargs['loss_function'] = dataset.get_loss_function()

        encoder = encoder_model(**encoder_kwargs)
        encoder.after_init()
        corruption = corruption_model(**test['corruption_kwargs'])
        test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)

        # Run TestRun
        test_run.run(retrain=False)

        # Cleanup to avoid out-of-memory situations when running lots of
        # tests: drop the references before the next iteration allocates.
        del test_run, corruption, encoder, dataset
# Script entry point: execute all configured test runs only when this file is
# run directly, not when it is imported as a module.
if __name__ == "__main__":
    run_tests()