Initial, very basic framework for running comparison tests

This commit is contained in:
Kevin Alberts 2020-11-24 17:19:46 +01:00
commit 51cc5d1d30
Signed by: Kurocon
GPG key ID: BCD496FEBA0C6BC1
10 changed files with 444 additions and 0 deletions

0
models/__init__.py Normal file
View file

30
models/base_corruption.py Normal file
View file

@ -0,0 +1,30 @@
from models.base_dataset import BaseDataset
class BaseCorruption:
    """
    Common parent for corruption models.

    Provides a display name and declares the `corrupt` hook; the hook itself
    raises NotImplementedError here and has to be supplied by a subclass.
    """
    # Default display name, used unless an instance is given its own name.
    name = "BaseCorruption"

    def __init__(self, name: str = None):
        """Optionally override the class-level display name for this instance."""
        if name is None:
            return
        self.name = name

    def __str__(self):
        """Return the display name of this corruption model."""
        return format(self.name)

    @classmethod
    def corrupt(cls, dataset: BaseDataset) -> BaseDataset:
        """
        Apply the corruption to `dataset` and return the resulting dataset.

        Not implemented on the base class.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()
class NoCorruption(BaseCorruption):
    """
    Identity corruption model: the dataset is handed back untouched.
    """
    name = "No corruption"

    @classmethod
    def corrupt(cls, dataset: BaseDataset) -> BaseDataset:
        """Return the very same `dataset` object that was passed in, without modifying it."""
        unchanged = dataset
        return unchanged

71
models/base_dataset.py Normal file
View file

@ -0,0 +1,71 @@
import math
from typing import Union, Optional
class BaseDataset:
    """
    Base dataset holding a sequence of entries that can be split into a training and a testing subset.

    Loading is not implemented on this base class (`load` raises NotImplementedError);
    data can also be supplied directly through `get_new`.
    """
    # Train amount is either a proportion of data that should be used as training data (between 0 and 1),
    # or an integer indicating how many entries should be used as training data (e.g. 1000, 2000)
    #
    # So 0.2 would mean 20% of all data in the dataset (200 if dataset is 1000 entries) is used as training data,
    # and 1000 would mean that 1000 entries are used as training data, regardless of the size of the dataset.
    TRAIN_AMOUNT = 0.2

    # Default display name, used unless an instance is given its own name.
    name = "BaseDataset"
    # Path the data was (or should be) loaded from, if any.
    _source_path = None
    # The loaded entries; None means no data has been loaded yet.
    _data = None
    # Cached train/test subsets, filled in lazily by `_subdivide`.
    _trainset: 'BaseDataset' = None
    _testset: 'BaseDataset' = None

    def __init__(self, name: Optional[str] = None):
        """Optionally override the class-level display name for this instance."""
        if name is not None:
            self.name = name

    def __str__(self):
        """Return the name together with the number of loaded entries, if any."""
        if self._data is not None:
            return f"{self.name} ({len(self._data)} objects)"
        else:
            return f"{self.name} (no data loaded)"

    @classmethod
    def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None,
                train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
        """
        Build a dataset of this class directly from already available values.

        :param name: Display name of the new dataset.
        :param data: Entries of the dataset, or None if nothing is loaded yet.
        :param source_path: Path the data originates from, if any.
        :param train_set: Pre-computed training subset, if any.
        :param test_set: Pre-computed testing subset, if any.
        :return: The newly created dataset instance.
        """
        dset = cls()
        # The name used to be silently dropped here, so every subset ended up with the class default name.
        # It is assigned after construction so that the no-argument `cls()` call stays unchanged.
        if name is not None:
            dset.name = name
        dset._data = data
        dset._source_path = source_path
        dset._trainset = train_set
        dset._testset = test_set
        return dset

    def load(self, name: str, path: str):
        """
        Record the name and source path of the dataset; the actual loading is not implemented here.

        :param name: Display name of the dataset.
        :param path: Path the data should be loaded from.
        :raises NotImplementedError: always, after name and path have been stored.
        """
        # Store the given name (previously the builtin `str` type was assigned by mistake).
        self.name = name
        self._source_path = path
        raise NotImplementedError()

    def _subdivide(self, amount: Union[int, float]):
        """
        Split the loaded data into a training and a testing dataset and cache both.

        The first part of the data becomes the training set, the remainder the testing set.

        :param amount: Either a fraction strictly between 0 and 1 (share of the data used for training),
                       or a positive integer (number of entries used for training).
        :raises ValueError: if no data is loaded or `amount` is neither of the accepted forms.
        """
        if self._data is None:
            raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")

        if isinstance(amount, float) and 0 < amount < 1:
            size_train = math.floor(len(self._data) * amount)
        elif isinstance(amount, int) and amount > 0:
            size_train = amount
        else:
            raise ValueError("Cannot subdivide! Invalid amount given, "
                             "must be either a fraction between 0 and 1, or an integer.")

        train_data = self._data[:size_train]
        test_data = self._data[size_train:]

        self._trainset = self.__class__.get_new(
            name=f"{self.name} Training", data=train_data, source_path=self._source_path,
        )
        self._testset = self.__class__.get_new(
            name=f"{self.name} Testing", data=test_data, source_path=self._source_path,
        )

    def get_train(self) -> 'BaseDataset':
        """Return the training subset, splitting the data with TRAIN_AMOUNT first if needed."""
        # Compare against None explicitly: a subclass that defines __len__/__bool__ would make an
        # empty subset falsy and trigger a needless re-split on every call.
        if self._trainset is None or self._testset is None:
            self._subdivide(self.TRAIN_AMOUNT)
        return self._trainset

    def get_test(self) -> 'BaseDataset':
        """Return the testing subset, splitting the data with TRAIN_AMOUNT first if needed."""
        # See get_train for why this is an identity check rather than a truthiness check.
        if self._trainset is None or self._testset is None:
            self._subdivide(self.TRAIN_AMOUNT)
        return self._testset

20
models/base_encoder.py Normal file
View file

@ -0,0 +1,20 @@
from typing import Optional
from models.base_dataset import BaseDataset
class BaseEncoder:
    """
    Common parent for encoders.

    Carries a display name and declares the `train` and `test` hooks, both of
    which raise NotImplementedError here and have to be supplied by a subclass.
    """
    # Default display name, used unless an instance is given its own name.
    name = "BaseEncoder"

    def __init__(self, name: Optional[str] = None):
        """Optionally override the class-level display name for this instance."""
        if name is None:
            return
        self.name = name

    def __str__(self):
        """Return the display name of this encoder."""
        return format(self.name)

    def train(self, dataset: BaseDataset):
        """
        Train the encoder on `dataset`. Not implemented on the base class.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

    def test(self, dataset: BaseDataset):
        """
        Test the encoder on `dataset`. Not implemented on the base class.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

29
models/test_run.py Normal file
View file

@ -0,0 +1,29 @@
from models.base_corruption import BaseCorruption
from models.base_dataset import BaseDataset
from models.base_encoder import BaseEncoder
class TestRun:
    """
    A single test run combining one dataset, one encoder and one corruption model.

    `run` validates that all three parts are present and then delegates to `_run`;
    `_run` and `get_metrics` raise NotImplementedError on this class.
    """
    dataset: BaseDataset = None
    encoder: BaseEncoder = None
    corruption: BaseCorruption = None

    def __init__(self, dataset: BaseDataset, encoder: BaseEncoder, corruption: BaseCorruption):
        """Store the three components this run operates on."""
        self.dataset, self.encoder, self.corruption = dataset, encoder, corruption

    def run(self):
        """
        Check that every component is set and execute the run.

        :return: Whatever `_run` returns.
        :raises ValueError: if the dataset, encoder or corruption is None (checked in that order).
        """
        # Checked in a fixed order so the first missing component determines the error message.
        required = (
            (self.dataset, "Cannot run test! Dataset is not specified."),
            (self.encoder, "Cannot run test! AutoEncoder is not specified."),
            (self.corruption, "Cannot run test! Corruption method is not specified."),
        )
        for component, complaint in required:
            if component is None:
                raise ValueError(complaint)
        return self._run()

    def _run(self):
        """
        Perform the actual run. Not implemented on this class.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

    def get_metrics(self):
        """
        Return the metrics of the run. Not implemented on this class.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()