From 51cc5d1d307bca48fb2a01ce1058178b90d55843 Mon Sep 17 00:00:00 2001 From: Kevin Alberts Date: Tue, 24 Nov 2020 17:19:46 +0100 Subject: [PATCH] Initial, very basic framework for running comparison tests --- .gitignore | 223 ++++++++++++++++++++++++++++++++++++++ config.example.py | 14 +++ datasets/.gitkeep | 0 main.py | 57 ++++++++++ models/__init__.py | 0 models/base_corruption.py | 30 +++++ models/base_dataset.py | 71 ++++++++++++ models/base_encoder.py | 20 ++++ models/test_run.py | 29 +++++ saved_models/.gitkeep | 0 10 files changed, 444 insertions(+) create mode 100644 .gitignore create mode 100644 config.example.py create mode 100644 datasets/.gitkeep create mode 100644 main.py create mode 100644 models/__init__.py create mode 100644 models/base_corruption.py create mode 100644 models/base_dataset.py create mode 100644 models/base_encoder.py create mode 100644 models/test_run.py create mode 100644 saved_models/.gitkeep diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ad27643 --- /dev/null +++ b/.gitignore @@ -0,0 +1,223 @@ +# Created by .ignore support plugin (hsz.mobi) +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. 
+# .idea/artifacts
+# .idea/compiler.xml
+# .idea/jarRepositories.xml
+# .idea/modules.xml
+# .idea/*.iml
+# .idea/modules
+# *.iml
+# *.ipr
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser
+
+.idea/
+
+/saved_models/*
+!/saved_models/.gitkeep
+
+/datasets/*
+!/datasets/.gitkeep
+
+/config.py
diff --git a/config.example.py b/config.example.py
new file mode 100644
index 0000000..b003cd9
--- /dev/null
+++ b/config.example.py
@@ -0,0 +1,14 @@
+MODEL_STORAGE_BASE_PATH = "/path/to/this/project/saved_models"  # NOTE(review): not read by any code in this patch yet
+DATASET_STORAGE_BASE_PATH = "/path/to/this/project/datasets"  # NOTE(review): not read by any code in this patch yet
+
+TEST_RUNS = [  # consumed by main.run_tests(); each entry describes one comparison test run
+    {
+        'name': "Basic test run",  # only used for log output in main.run_tests()
+        'encoder_model': "models.base_encoder.BaseEncoder",  # dotted path, resolved via main.load_dotted_path
+        'encoder_kwargs': {},  # passed as **kwargs to the encoder constructor
+        'dataset_model': "models.base_dataset.BaseDataset",  # dotted path, resolved via main.load_dotted_path
+        'dataset_kwargs': {},  # passed as **kwargs to the dataset constructor
+        'corruption_model': "models.base_corruption.NoCorruption",  # dotted path, resolved via main.load_dotted_path
+        'corruption_kwargs': {},  # passed as **kwargs to the corruption constructor
+    },
+]
diff --git a/datasets/.gitkeep b/datasets/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..46127f1
--- /dev/null
+++ b/main.py
@@ -0,0 +1,57 @@
+import importlib
+
+import config
+import logging
+
+from models.base_corruption import BaseCorruption
+from models.base_dataset import BaseDataset
+from models.base_encoder import BaseEncoder
+from models.test_run import TestRun
+
+logger = logging.getLogger("main.py")
+logger.setLevel(logging.DEBUG)
+
+ch = logging.StreamHandler()  # DEBUG-level console handler for this module's logger
+ch.setLevel(logging.DEBUG)
+formatter = 
logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+
+def load_dotted_path(path):  # resolve "pkg.mod.Class" to the imported class object
+    split_path = path.split(".")
+    modulename, classname = ".".join(split_path[:-1]), split_path[-1]
+    model = getattr(importlib.import_module(modulename), classname)
+    return model
+
+
+def run_tests():  # run every test configured in config.TEST_RUNS
+    for test in config.TEST_RUNS:
+        logger.info(f"Running test run '{test['name']}'...")
+
+        # Load dataset model (explicit raise: `assert` would be stripped under `python -O`)
+        dataset_model = load_dotted_path(test['dataset_model'])
+        if not issubclass(dataset_model, BaseDataset): raise TypeError(f"Invalid dataset_model: '{dataset_model.__name__}', should be subclass of BaseDataset.")
+        logger.debug(f"Using dataset model '{dataset_model.__name__}'")
+
+        # Load auto-encoder model
+        encoder_model = load_dotted_path(test['encoder_model'])
+        if not issubclass(encoder_model, BaseEncoder): raise TypeError(f"Invalid encoder_model: '{encoder_model.__name__}', should be subclass of BaseEncoder.")
+        logger.debug(f"Using encoder model '{encoder_model.__name__}'")
+
+        # Load corruption model
+        corruption_model = load_dotted_path(test['corruption_model'])
+        if not issubclass(corruption_model, BaseCorruption): raise TypeError(f"Invalid corruption_model: '{corruption_model.__name__}', should be subclass of BaseCorruption.")
+        logger.debug(f"Using corruption model '{corruption_model.__name__}'")
+
+        # Create TestRun instance (models are instantiated here with their configured kwargs)
+        test_run = TestRun(dataset=dataset_model(**test['dataset_kwargs']),
+                           encoder=encoder_model(**test['encoder_kwargs']),
+                           corruption=corruption_model(**test['corruption_kwargs']))
+
+        # Run TestRun
+        test_run.run()
+
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/base_corruption.py b/models/base_corruption.py
new file mode 100644
index 0000000..fb754f0
--- /dev/null
+++ b/models/base_corruption.py
@@ -0,0 +1,30 @@
+from models.base_dataset import BaseDataset
+
+
+class BaseCorruption:
+    """
+    Base corruption model that is not implemented.
+    """
+    name = "BaseCorruption"  # class-level default, shadowed per-instance when a name is given
+
+    def __init__(self, name: str = None):  # NOTE(review): annotation should be Optional[str]
+        if name is not None:
+            self.name = name
+
+    def __str__(self):
+        return f"{self.name}"
+
+    @classmethod
+    def corrupt(cls, dataset: BaseDataset) -> BaseDataset:  # abstract in the base class; subclasses return a corrupted dataset
+        raise NotImplementedError()
+
+
+class NoCorruption(BaseCorruption):
+    """
+    Corruption model that does not corrupt the dataset at all.
+    """
+    name = "No corruption"
+
+    @classmethod
+    def corrupt(cls, dataset: BaseDataset) -> BaseDataset:  # identity: hands the dataset back untouched
+        return dataset
diff --git a/models/base_dataset.py b/models/base_dataset.py
new file mode 100644
index 0000000..8126546
--- /dev/null
+++ b/models/base_dataset.py
@@ -0,0 +1,71 @@
+import math
+from typing import Union, Optional
+
+
+class BaseDataset:
+
+    # Train amount is either a proportion of data that should be used as training data (between 0 and 1),
+    # or an integer indicating how many entries should be used as training data (e.g. 1000, 2000)
+    #
+    # So 0.2 would mean 20% of all data in the dataset (200 if dataset is 1000 entries) is used as training data,
+    # and 1000 would mean that 1000 entries are used as training data, regardless of the size of the dataset. 
+    TRAIN_AMOUNT = 0.2  # default split; interpreted per the comment above
+
+    name = "BaseDataset"
+    _source_path = None  # where the data was loaded from (None until load())
+    _data = None  # None means "not loaded yet"
+    _trainset: 'BaseDataset' = None  # lazily created by _subdivide()
+    _testset: 'BaseDataset' = None  # lazily created by _subdivide()
+
+    def __init__(self, name: Optional[str] = None):
+        if name is not None:
+            self.name = name
+
+    def __str__(self):
+        if self._data is not None:
+            return f"{self.name} ({len(self._data)} objects)"
+        else:
+            return f"{self.name} (no data loaded)"
+
+    @classmethod
+    def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None,
+                train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
+        dset = cls(name=name)  # was cls(): the name argument was silently dropped
+        dset._data = data
+        dset._source_path = source_path
+        dset._trainset = train_set
+        dset._testset = test_set
+        return dset
+
+    def load(self, name: str, path: str):
+        self.name = name  # was `self.name = str` (assigned the builtin type, not the argument)
+        self._source_path = path
+        raise NotImplementedError()
+
+    def _subdivide(self, amount: Union[int, float]):
+        if self._data is None:
+            raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
+
+        if isinstance(amount, float) and 0 < amount < 1:
+            size_train = math.floor(len(self._data) * amount)
+            train_data = self._data[:size_train]
+            test_data = self._data[size_train:]
+        elif isinstance(amount, int) and amount > 0:
+            train_data = self._data[:amount]
+            test_data = self._data[amount:]
+        else:
+            raise ValueError("Cannot subdivide! 
Invalid amount given, "
+                             "must be either a fraction between 0 and 1, or an integer.")
+
+        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path)
+        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path)
+
+    def get_train(self) -> 'BaseDataset':
+        if self._trainset is None or self._testset is None:  # was a truthiness test; None is the explicit "not yet split" sentinel
+            self._subdivide(self.TRAIN_AMOUNT)
+        return self._trainset
+
+    def get_test(self) -> 'BaseDataset':
+        if self._trainset is None or self._testset is None:  # was a truthiness test; None is the explicit "not yet split" sentinel
+            self._subdivide(self.TRAIN_AMOUNT)
+        return self._testset
diff --git a/models/base_encoder.py b/models/base_encoder.py
new file mode 100644
index 0000000..8a2fb99
--- /dev/null
+++ b/models/base_encoder.py
@@ -0,0 +1,20 @@
+from typing import Optional
+
+from models.base_dataset import BaseDataset
+
+
+class BaseEncoder:
+    name = "BaseEncoder"
+
+    def __init__(self, name: Optional[str] = None):
+        if name is not None:
+            self.name = name
+
+    def __str__(self):
+        return f"{self.name}"
+
+    def train(self, dataset: BaseDataset):
+        raise NotImplementedError()
+
+    def test(self, dataset: BaseDataset):
+        raise NotImplementedError()
diff --git a/models/test_run.py b/models/test_run.py
new file mode 100644
index 0000000..fca22a0
--- /dev/null
+++ b/models/test_run.py
@@ -0,0 +1,29 @@
+from models.base_corruption import BaseCorruption
+from models.base_dataset import BaseDataset
+from models.base_encoder import BaseEncoder
+
+
+class TestRun:
+    dataset: BaseDataset = None
+    encoder: BaseEncoder = None
+    corruption: BaseCorruption = None
+
+    def __init__(self, dataset: BaseDataset, encoder: BaseEncoder, corruption: BaseCorruption):
+        self.dataset = dataset
+        self.encoder = encoder
+        self.corruption = corruption
+
+    def run(self):
+        if self.dataset is None:
+            raise ValueError("Cannot run test! Dataset is not specified.")
+        if self.encoder is None:
+            raise ValueError("Cannot run test! 
AutoEncoder is not specified.")
+        if self.corruption is None:
+            raise ValueError("Cannot run test! Corruption method is not specified.")
+        return self._run()
+
+    def _run(self):  # NOTE(review): abstract hook — concrete test runs implement the actual pipeline
+        raise NotImplementedError()
+
+    def get_metrics(self):  # NOTE(review): abstract — metrics reporting not implemented in this patch
+        raise NotImplementedError()
diff --git a/saved_models/.gitkeep b/saved_models/.gitkeep
new file mode 100644
index 0000000..e69de29