58 lines
2.3 KiB
Python
58 lines
2.3 KiB
Python
import config
|
|
import logging.config
|
|
|
|
# Get logging as early as possible!
|
|
logging.config.fileConfig("logging.conf")
|
|
|
|
from utils import load_dotted_path
|
|
|
|
from models.base_corruption import BaseCorruption
|
|
from models.base_dataset import BaseDataset
|
|
from models.base_encoder import BaseEncoder
|
|
from models.test_run import TestRun
|
|
|
|
|
|
def _load_model(dotted_path, base_class, label, logger):
    """Import the object at *dotted_path* and check it subclasses *base_class*.

    Args:
        dotted_path: Import path handed to ``load_dotted_path``.
        base_class: Class the loaded object must derive from
            (e.g. ``BaseDataset``).
        label: Short model kind used in messages, e.g. ``"dataset"``,
            ``"encoder"`` or ``"corruption"``.
        logger: Logger that receives the "Using ... model" debug message.

    Returns:
        The loaded class.

    Raises:
        TypeError: If the loaded object is not a class deriving from
            *base_class*.
    """
    model = load_dotted_path(dotted_path)
    # Explicit check rather than ``assert``: assertions are removed when
    # Python runs with -O, which would let a misconfigured model through.
    # The ``isinstance(model, type)`` guard stops non-class objects from
    # reaching issubclass(), which would fail with a less helpful error.
    if not (isinstance(model, type) and issubclass(model, base_class)):
        name = getattr(model, "__name__", repr(model))
        raise TypeError(
            f"Invalid {label}_model: '{name}', "
            f"should be subclass of {base_class.__name__}."
        )
    logger.debug("Using %s model '%s'", label, model.__name__)
    return model


def run_tests():
    """Execute every test run listed in ``config.TEST_RUNS``, one at a time.

    Each entry is read as a mapping with the keys ``'name'``,
    ``'dataset_model'``, ``'encoder_model'`` and ``'corruption_model'``
    (dotted import paths). It may also have ``'dataset_kwargs'``,
    ``'encoder_kwargs'`` and ``'corruption_kwargs'``; a missing or None
    kwargs entry means "no keyword arguments".

    If ``input_shape`` or ``loss_function`` is absent/None in the encoder
    kwargs, it is filled in from the dataset. This happens on a copy, so
    ``config.TEST_RUNS`` itself is not modified.

    Raises:
        TypeError: If a configured model path does not resolve to a subclass
            of the expected base class.
    """
    logger = logging.getLogger("main.run_tests")

    for test in config.TEST_RUNS:
        logger.info("Running test run '%s'...", test["name"])

        # Resolve and validate all three model classes before building anything.
        dataset_model = _load_model(
            test["dataset_model"], BaseDataset, "dataset", logger
        )
        encoder_model = _load_model(
            test["encoder_model"], BaseEncoder, "encoder", logger
        )
        corruption_model = _load_model(
            test["corruption_model"], BaseCorruption, "corruption", logger
        )

        # Create the TestRun components.
        dataset = dataset_model(**(test.get("dataset_kwargs") or {}))

        # Work on a copy: writing the dataset-derived defaults back into the
        # config dict would leak them into any later run sharing that dict.
        encoder_kwargs = dict(test.get("encoder_kwargs") or {})
        if encoder_kwargs.get("input_shape") is None:
            encoder_kwargs["input_shape"] = dataset.get_input_shape()
        if encoder_kwargs.get("loss_function") is None:
            encoder_kwargs["loss_function"] = dataset.get_loss_function()
        encoder = encoder_model(**encoder_kwargs)
        encoder.after_init()

        corruption = corruption_model(**(test.get("corruption_kwargs") or {}))

        test_run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
        test_run.run(retrain=False)

        # Drop this frame's references before the next iteration builds new
        # objects, to keep peak memory down when running many tests. (Objects
        # are only freed if nothing else still references them.)
        del test_run, corruption, encoder, dataset
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: execute every configured test run.
    run_tests()
|