Initial, very basic framework for running comparison tests
This commit is contained in:
		
						commit
						51cc5d1d30
					
				
					 10 changed files with 444 additions and 0 deletions
				
			
		
							
								
								
									
										0
									
								
								models/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								models/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										30
									
								
								models/base_corruption.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								models/base_corruption.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,30 @@
 | 
			
		|||
from models.base_dataset import BaseDataset
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseCorruption:
 | 
			
		||||
    """
 | 
			
		||||
    Base corruption model that is not implemented.
 | 
			
		||||
    """
 | 
			
		||||
    name = "BaseCorruption"
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name: str = None):
 | 
			
		||||
        if name is not None:
 | 
			
		||||
            self.name = name
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"{self.name}"
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def corrupt(cls, dataset: BaseDataset) -> BaseDataset:
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NoCorruption(BaseCorruption):
 | 
			
		||||
    """
 | 
			
		||||
    Corruption model that does not corrupt the dataset at all.
 | 
			
		||||
    """
 | 
			
		||||
    name = "No corruption"
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def corrupt(cls, dataset: BaseDataset) -> BaseDataset:
 | 
			
		||||
        return dataset
 | 
			
		||||
							
								
								
									
										71
									
								
								models/base_dataset.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								models/base_dataset.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,71 @@
 | 
			
		|||
import math
 | 
			
		||||
from typing import Union, Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseDataset:
 | 
			
		||||
 | 
			
		||||
    # Train amount is either a proportion of data that should be used as training data (between 0 and 1),
 | 
			
		||||
    # or an integer indicating how many entries should be used as training data (e.g. 1000, 2000)
 | 
			
		||||
    #
 | 
			
		||||
    # So 0.2 would mean 20% of all data in the dataset (200 if dataset is 1000 entries) is used as training data,
 | 
			
		||||
    # and 1000 would mean that 1000 entries are used as training data, regardless of the size of the dataset.
 | 
			
		||||
    TRAIN_AMOUNT = 0.2
 | 
			
		||||
 | 
			
		||||
    name = "BaseDataset"
 | 
			
		||||
    _source_path = None
 | 
			
		||||
    _data = None
 | 
			
		||||
    _trainset: 'BaseDataset' = None
 | 
			
		||||
    _testset: 'BaseDataset' = None
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name: Optional[str] = None):
 | 
			
		||||
        if name is not None:
 | 
			
		||||
            self.name = name
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        if self._data is not None:
 | 
			
		||||
            return f"{self.name} ({len(self._data)} objects)"
 | 
			
		||||
        else:
 | 
			
		||||
            return f"{self.name} (no data loaded)"
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None,
 | 
			
		||||
                train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
 | 
			
		||||
        dset = cls()
 | 
			
		||||
        dset._data = data
 | 
			
		||||
        dset._source_path = source_path
 | 
			
		||||
        dset._trainset = train_set
 | 
			
		||||
        dset._testset = test_set
 | 
			
		||||
        return dset
 | 
			
		||||
 | 
			
		||||
    def load(self, name: str, path: str):
 | 
			
		||||
        self.name = str
 | 
			
		||||
        self._source_path = path
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def _subdivide(self, amount: Union[int, float]):
 | 
			
		||||
        if self._data is None:
 | 
			
		||||
            raise ValueError("Cannot subdivide! Data not loaded, call `load()` first to load data")
 | 
			
		||||
 | 
			
		||||
        if isinstance(amount, float) and 0 < amount < 1:
 | 
			
		||||
            size_train = math.floor(len(self._data) * amount)
 | 
			
		||||
            train_data = self._data[:size_train]
 | 
			
		||||
            test_data = self._data[size_train:]
 | 
			
		||||
        elif isinstance(amount, int) and amount > 0:
 | 
			
		||||
            train_data = self._data[:amount]
 | 
			
		||||
            test_data = self._data[amount:]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("Cannot subdivide! Invalid amount given, "
 | 
			
		||||
                             "must be either a fraction between 0 and 1, or an integer.")
 | 
			
		||||
 | 
			
		||||
        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=train_data, source_path=self._source_path)
 | 
			
		||||
        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data, source_path=self._source_path)
 | 
			
		||||
 | 
			
		||||
    def get_train(self) -> 'BaseDataset':
 | 
			
		||||
        if not self._trainset or not self._testset:
 | 
			
		||||
            self._subdivide(self.TRAIN_AMOUNT)
 | 
			
		||||
        return self._trainset
 | 
			
		||||
 | 
			
		||||
    def get_test(self) -> 'BaseDataset':
 | 
			
		||||
        if not self._trainset or not self._testset:
 | 
			
		||||
            self._subdivide(self.TRAIN_AMOUNT)
 | 
			
		||||
        return self._testset
 | 
			
		||||
							
								
								
									
										20
									
								
								models/base_encoder.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								models/base_encoder.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,20 @@
 | 
			
		|||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from models.base_dataset import BaseDataset
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseEncoder:
 | 
			
		||||
    name = "BaseEncoder"
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name: Optional[str] = None):
 | 
			
		||||
        if name is not None:
 | 
			
		||||
            self.name = name
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"{self.name}"
 | 
			
		||||
 | 
			
		||||
    def train(self, dataset: BaseDataset):
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def test(self, dataset: BaseDataset):
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
							
								
								
									
										29
									
								
								models/test_run.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								models/test_run.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,29 @@
 | 
			
		|||
from models.base_corruption import BaseCorruption
 | 
			
		||||
from models.base_dataset import BaseDataset
 | 
			
		||||
from models.base_encoder import BaseEncoder
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestRun:
 | 
			
		||||
    dataset: BaseDataset = None
 | 
			
		||||
    encoder: BaseEncoder = None
 | 
			
		||||
    corruption: BaseCorruption = None
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dataset: BaseDataset, encoder: BaseEncoder, corruption: BaseCorruption):
 | 
			
		||||
        self.dataset = dataset
 | 
			
		||||
        self.encoder = encoder
 | 
			
		||||
        self.corruption = corruption
 | 
			
		||||
 | 
			
		||||
    def run(self):
 | 
			
		||||
        if self.dataset is None:
 | 
			
		||||
            raise ValueError("Cannot run test! Dataset is not specified.")
 | 
			
		||||
        if self.encoder is None:
 | 
			
		||||
            raise ValueError("Cannot run test! AutoEncoder is not specified.")
 | 
			
		||||
        if self.corruption is None:
 | 
			
		||||
            raise ValueError("Cannot run test! Corruption method is not specified.")
 | 
			
		||||
        return self._run()
 | 
			
		||||
 | 
			
		||||
    def _run(self):
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_metrics(self):
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
		Reference in a new issue