First basic auto-encoder and CIFAR-10 dataset implemented
This commit is contained in:
		
							parent
							
								
									51cc5d1d30
								
							
						
					
					
						commit
						62f9b873e9
					
				
					 10 changed files with 509 additions and 27 deletions
				
			
		
							
								
								
									
										3
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							| 
						 | 
					@ -220,4 +220,7 @@ fabric.properties
 | 
				
			||||||
/datasets/*
 | 
					/datasets/*
 | 
				
			||||||
!/datasets/.gitkeep
 | 
					!/datasets/.gitkeep
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/test_temp
 | 
				
			||||||
 | 
					/train_temp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/config.py
 | 
					/config.py
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										28
									
								
								logging.conf
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								logging.conf
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,28 @@
 | 
				
			||||||
 | 
					[loggers]
 | 
				
			||||||
 | 
					keys=root
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[handlers]
 | 
				
			||||||
 | 
					keys=consoleHandler,fileHandler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[formatters]
 | 
				
			||||||
 | 
					keys=simpleFormatter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[logger_root]
 | 
				
			||||||
 | 
					level=DEBUG
 | 
				
			||||||
 | 
					handlers=consoleHandler,fileHandler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[handler_fileHandler]
 | 
				
			||||||
 | 
					class=FileHandler
 | 
				
			||||||
 | 
					level=DEBUG
 | 
				
			||||||
 | 
					formatter=simpleFormatter
 | 
				
			||||||
 | 
					args=('output.log', 'w')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[handler_consoleHandler]
 | 
				
			||||||
 | 
					class=StreamHandler
 | 
				
			||||||
 | 
					level=INFO
 | 
				
			||||||
 | 
					formatter=simpleFormatter
 | 
				
			||||||
 | 
					args=(sys.stdout,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[formatter_simpleFormatter]
 | 
				
			||||||
 | 
					format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
 | 
				
			||||||
 | 
					datefmt=
 | 
				
			||||||
							
								
								
									
										18
									
								
								main.py
									
										
									
									
									
								
							
							
						
						
									
										18
									
								
								main.py
									
										
									
									
									
								
							| 
						 | 
					@ -1,22 +1,17 @@
 | 
				
			||||||
import importlib
 | 
					import importlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import config
 | 
					import config
 | 
				
			||||||
import logging
 | 
					import logging.config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Get logging as early as possible!
 | 
				
			||||||
 | 
					logging.config.fileConfig("logging.conf")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from models.base_corruption import BaseCorruption
 | 
					from models.base_corruption import BaseCorruption
 | 
				
			||||||
from models.base_dataset import BaseDataset
 | 
					from models.base_dataset import BaseDataset
 | 
				
			||||||
from models.base_encoder import BaseEncoder
 | 
					from models.base_encoder import BaseEncoder
 | 
				
			||||||
from models.test_run import TestRun
 | 
					from models.test_run import TestRun
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger("main.py")
 | 
					 | 
				
			||||||
logger.setLevel(logging.DEBUG)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
ch = logging.StreamHandler()
 | 
					 | 
				
			||||||
ch.setLevel(logging.DEBUG)
 | 
					 | 
				
			||||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 | 
					 | 
				
			||||||
ch.setFormatter(formatter)
 | 
					 | 
				
			||||||
logger.addHandler(ch)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def load_dotted_path(path):
 | 
					def load_dotted_path(path):
 | 
				
			||||||
    split_path = path.split(".")
 | 
					    split_path = path.split(".")
 | 
				
			||||||
| 
						 | 
					@ -26,6 +21,7 @@ def load_dotted_path(path):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run_tests():
 | 
					def run_tests():
 | 
				
			||||||
 | 
					    logger = logging.getLogger("main.run_tests")
 | 
				
			||||||
    for test in config.TEST_RUNS:
 | 
					    for test in config.TEST_RUNS:
 | 
				
			||||||
        logger.info(f"Running test run '{test['name']}'...")
 | 
					        logger.info(f"Running test run '{test['name']}'...")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -50,7 +46,7 @@ def run_tests():
 | 
				
			||||||
                           corruption=corruption_model(**test['corruption_kwargs']))
 | 
					                           corruption=corruption_model(**test['corruption_kwargs']))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Run TestRun
 | 
					        # Run TestRun
 | 
				
			||||||
        test_run.run()
 | 
					        test_run.run(retrain=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,8 +1,11 @@
 | 
				
			||||||
import math
 | 
					import math, logging
 | 
				
			||||||
from typing import Union, Optional
 | 
					from typing import Union, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.utils.data import Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseDataset:
 | 
					
 | 
				
			||||||
 | 
					class BaseDataset(Dataset):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Train amount is either a proportion of data that should be used as training data (between 0 and 1),
 | 
					    # Train amount is either a proportion of data that should be used as training data (between 0 and 1),
 | 
				
			||||||
    # or an integer indicating how many entries should be used as training data (e.g. 1000, 2000)
 | 
					    # or an integer indicating how many entries should be used as training data (e.g. 1000, 2000)
 | 
				
			||||||
| 
						 | 
					@ -16,10 +19,14 @@ class BaseDataset:
 | 
				
			||||||
    _data = None
 | 
					    _data = None
 | 
				
			||||||
    _trainset: 'BaseDataset' = None
 | 
					    _trainset: 'BaseDataset' = None
 | 
				
			||||||
    _testset: 'BaseDataset' = None
 | 
					    _testset: 'BaseDataset' = None
 | 
				
			||||||
 | 
					    transform = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, name: Optional[str] = None):
 | 
					    def __init__(self, name: Optional[str] = None, path: Optional[str] = None):
 | 
				
			||||||
 | 
					        self.log = logging.getLogger(self.__class__.__name__)
 | 
				
			||||||
        if name is not None:
 | 
					        if name is not None:
 | 
				
			||||||
            self.name = name
 | 
					            self.name = name
 | 
				
			||||||
 | 
					        if path is not None:
 | 
				
			||||||
 | 
					            self._source_path = path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __str__(self):
 | 
					    def __str__(self):
 | 
				
			||||||
        if self._data is not None:
 | 
					        if self._data is not None:
 | 
				
			||||||
| 
						 | 
					@ -27,19 +34,30 @@ class BaseDataset:
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return f"{self.name} (no data loaded)"
 | 
					            return f"{self.name} (no data loaded)"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # __len__ so that len(dataset) returns the size of the dataset.
 | 
				
			||||||
 | 
					    # __getitem__ to support the indexing such that dataset[i] can be used to get ith sample
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        return len(self._data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getitem__(self, item):
 | 
				
			||||||
 | 
					        return self.transform(self._data[item]) if self.transform else self._data[item]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None,
 | 
					    def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None,
 | 
				
			||||||
                train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
 | 
					                train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
 | 
				
			||||||
        dset = cls()
 | 
					        dset = cls()
 | 
				
			||||||
 | 
					        dset.name = name
 | 
				
			||||||
        dset._data = data
 | 
					        dset._data = data
 | 
				
			||||||
        dset._source_path = source_path
 | 
					        dset._source_path = source_path
 | 
				
			||||||
        dset._trainset = train_set
 | 
					        dset._trainset = train_set
 | 
				
			||||||
        dset._testset = test_set
 | 
					        dset._testset = test_set
 | 
				
			||||||
        return dset
 | 
					        return dset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def load(self, name: str, path: str):
 | 
					    def load(self, name: Optional[str] = None, path: Optional[str] = None):
 | 
				
			||||||
        self.name = str
 | 
					        if name is not None:
 | 
				
			||||||
        self._source_path = path
 | 
					            self.name = name
 | 
				
			||||||
 | 
					        if path is not None:
 | 
				
			||||||
 | 
					            self._source_path = path
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _subdivide(self, amount: Union[int, float]):
 | 
					    def _subdivide(self, amount: Union[int, float]):
 | 
				
			||||||
| 
						 | 
					@ -69,3 +87,13 @@ class BaseDataset:
 | 
				
			||||||
        if not self._trainset or not self._testset:
 | 
					        if not self._trainset or not self._testset:
 | 
				
			||||||
            self._subdivide(self.TRAIN_AMOUNT)
 | 
					            self._subdivide(self.TRAIN_AMOUNT)
 | 
				
			||||||
        return self._testset
 | 
					        return self._testset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_loader(self, dataset, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
 | 
				
			||||||
 | 
					        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
 | 
				
			||||||
 | 
					                                           num_workers=num_workers, pin_memory=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_train_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
 | 
				
			||||||
 | 
					        return self.get_loader(self.get_train(), batch_size=batch_size, num_workers=num_workers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_test_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
 | 
				
			||||||
 | 
					        return self.get_loader(self.get_test(), batch_size=batch_size, num_workers=num_workers)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,20 +1,172 @@
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from torchvision.utils import save_image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH
 | 
				
			||||||
from models.base_dataset import BaseDataset
 | 
					from models.base_dataset import BaseDataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseEncoder:
 | 
					class BaseEncoder(torch.nn.Module):
 | 
				
			||||||
 | 
					    # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
 | 
				
			||||||
    name = "BaseEncoder"
 | 
					    name = "BaseEncoder"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, name: Optional[str] = None):
 | 
					    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
 | 
				
			||||||
 | 
					        super(BaseEncoder, self).__init__()
 | 
				
			||||||
 | 
					        self.log = logging.getLogger(self.__class__.__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if name is not None:
 | 
					        if name is not None:
 | 
				
			||||||
            self.name = name
 | 
					            self.name = name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert input_shape != 0, f"Encoder {self.__class__.__name__} input_shape parameter should not be 0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.input_shape = input_shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Default fallbacks (can be overridden by sub implementations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 4 layer NN, halving the input each layer
 | 
				
			||||||
 | 
					        self.network = torch.nn.Sequential(
 | 
				
			||||||
 | 
					            torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2),
 | 
				
			||||||
 | 
					            torch.nn.ReLU(),
 | 
				
			||||||
 | 
					            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4),
 | 
				
			||||||
 | 
					            torch.nn.ReLU(),
 | 
				
			||||||
 | 
					            torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2),
 | 
				
			||||||
 | 
					            torch.nn.ReLU(),
 | 
				
			||||||
 | 
					            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Use GPU acceleration if available
 | 
				
			||||||
 | 
					        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Adam optimizer with learning rate 1e-3
 | 
				
			||||||
 | 
					        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Mean Squared Error loss function
 | 
				
			||||||
 | 
					        self.loss_function = torch.nn.MSELoss()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def after_init(self):
 | 
				
			||||||
 | 
					        self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with {len(list(self.network.children()))} layers on {self.device.type}. Optimizer: {self.optimizer}, Loss function: {self.loss_function}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, features):
 | 
				
			||||||
 | 
					        return self.network(features)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def save_model(self, filename):
 | 
				
			||||||
 | 
					        torch.save(self.state_dict(), os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.model"))
 | 
				
			||||||
 | 
					        with open(os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.meta"), 'w') as f:
 | 
				
			||||||
 | 
					            f.write(json.dumps({
 | 
				
			||||||
 | 
					                'name': self.name,
 | 
				
			||||||
 | 
					                'input_shape': self.input_shape
 | 
				
			||||||
 | 
					            }))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load_model(self, filename=None):
 | 
				
			||||||
 | 
					        if filename is None:
 | 
				
			||||||
 | 
					            filename = f"{self.name}"
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            loaded_model = torch.load(os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.model"), map_location=self.device)
 | 
				
			||||||
 | 
					            self.load_state_dict(loaded_model)
 | 
				
			||||||
 | 
					            self.to(self.device)
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					        except OSError as e:
 | 
				
			||||||
 | 
					            self.log.error(f"Could not load model '{filename}': {e}")
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def create_model_from_file(cls, filename, device=None):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            if device is None:
 | 
				
			||||||
 | 
					                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
				
			||||||
 | 
					            with open(os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.meta")) as f:
 | 
				
			||||||
 | 
					                model_kwargs = json.loads(f.read())
 | 
				
			||||||
 | 
					            model = cls(**model_kwargs)
 | 
				
			||||||
 | 
					            loaded_model = torch.load(os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.model"), map_location=device)
 | 
				
			||||||
 | 
					            model.load_state_dict(loaded_model)
 | 
				
			||||||
 | 
					            model.to(device)
 | 
				
			||||||
 | 
					            return model
 | 
				
			||||||
 | 
					        except OSError as e:
 | 
				
			||||||
 | 
					            log = logging.getLogger(cls.__name__)
 | 
				
			||||||
 | 
					            log.error(f"Could not load model '{filename}': {e}")
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __str__(self):
 | 
					    def __str__(self):
 | 
				
			||||||
        return f"{self.name}"
 | 
					        return f"{self.name}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def train(self, dataset: BaseDataset):
 | 
					    def train_encoder(self, dataset: BaseDataset, epochs: int = 20, batch_size: int = 128, num_workers: int = 4):
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        self.log.debug("Getting training dataset DataLoader.")
 | 
				
			||||||
 | 
					        train_loader = dataset.get_train_loader(batch_size=batch_size, num_workers=num_workers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test(self, dataset: BaseDataset):
 | 
					        # Puts module in training mode.
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        self.log.debug("Putting model into training mode.")
 | 
				
			||||||
 | 
					        self.train()
 | 
				
			||||||
 | 
					        self.to(self.device, non_blocking=True)
 | 
				
			||||||
 | 
					        self.loss_function.to(self.device, non_blocking=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        losses = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        outputs = None
 | 
				
			||||||
 | 
					        for epoch in range(epochs):
 | 
				
			||||||
 | 
					            self.log.debug(f"Start training epoch {epoch}...")
 | 
				
			||||||
 | 
					            loss = 0
 | 
				
			||||||
 | 
					            for batch_features in train_loader:
 | 
				
			||||||
 | 
					                # load batch features to the active device
 | 
				
			||||||
 | 
					                batch_features = batch_features.to(self.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # reset the gradients back to zero
 | 
				
			||||||
 | 
					                # PyTorch accumulates gradients on subsequent backward passes
 | 
				
			||||||
 | 
					                self.optimizer.zero_grad()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # compute reconstructions
 | 
				
			||||||
 | 
					                outputs = self(batch_features)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # compute training reconstruction loss
 | 
				
			||||||
 | 
					                train_loss = self.loss_function(outputs, batch_features)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # compute accumulated gradients
 | 
				
			||||||
 | 
					                train_loss.backward()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # perform parameter update based on current gradients
 | 
				
			||||||
 | 
					                self.optimizer.step()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # add the mini-batch training loss to epoch loss
 | 
				
			||||||
 | 
					                loss += train_loss.item()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # compute the epoch training loss
 | 
				
			||||||
 | 
					            loss = loss / len(train_loader)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # display the epoch training loss
 | 
				
			||||||
 | 
					            self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
 | 
				
			||||||
 | 
					            losses.append(loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Every 5 epochs, save a test image
 | 
				
			||||||
 | 
					            if epoch % 5 == 0:
 | 
				
			||||||
 | 
					                img = outputs.cpu().data
 | 
				
			||||||
 | 
					                img = img.view(img.size(0), 3, 32, 32)
 | 
				
			||||||
 | 
					                save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
 | 
				
			||||||
 | 
					                                             f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return losses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_encoder(self, dataset: BaseDataset, batch_size: int = 128, num_workers: int = 4):
 | 
				
			||||||
 | 
					        self.log.debug("Getting testing dataset DataLoader.")
 | 
				
			||||||
 | 
					        test_loader = dataset.get_test_loader(batch_size=batch_size, num_workers=num_workers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.log.debug(f"Start testing...")
 | 
				
			||||||
 | 
					        i = 0
 | 
				
			||||||
 | 
					        for batch in test_loader:
 | 
				
			||||||
 | 
					            img = batch.view(batch.size(0), 3, 32, 32)
 | 
				
			||||||
 | 
					            save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
 | 
				
			||||||
 | 
					                                         f'{self.name}_{dataset.name}_test_input_{i}.png'))
 | 
				
			||||||
 | 
					            # load batch features to the active device
 | 
				
			||||||
 | 
					            batch = batch.to(self.device)
 | 
				
			||||||
 | 
					            outputs = self(batch)
 | 
				
			||||||
 | 
					            img = outputs.cpu().data
 | 
				
			||||||
 | 
					            img = img.view(outputs.size(0), 3, 32, 32)
 | 
				
			||||||
 | 
					            save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
 | 
				
			||||||
 | 
					                                         f'{self.name}_{dataset.name}_test_reconstruction_{i}.png'))
 | 
				
			||||||
 | 
					            i += 1
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										48
									
								
								models/basic_encoder.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								models/basic_encoder.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,48 @@
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from models.base_encoder import BaseEncoder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BasicAutoEncoder(BaseEncoder):
 | 
				
			||||||
 | 
					    # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
 | 
				
			||||||
 | 
					    name = "BasicAutoEncoder"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
 | 
				
			||||||
 | 
					        self.log = logging.getLogger(self.__class__.__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Call superclass to initialize parameters.
 | 
				
			||||||
 | 
					        super(BasicAutoEncoder, self).__init__(name, input_shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Override parameters to custom values for this encoder type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # TODO - Hoe kan ik het beste bepalen hoe groot de intermediate layers moeten zijn?
 | 
				
			||||||
 | 
					        # - Proportioneel van input grootte naar opgegeven bottleneck grootte?
 | 
				
			||||||
 | 
					        # - Uit een paper plukken
 | 
				
			||||||
 | 
					        # - Zelf kiezen (e.g. helft elke keer, fixed aantal layers)?
 | 
				
			||||||
 | 
					        self.network = torch.nn.Sequential(
 | 
				
			||||||
 | 
					            torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2),
 | 
				
			||||||
 | 
					            torch.nn.ReLU(),
 | 
				
			||||||
 | 
					            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4),
 | 
				
			||||||
 | 
					            torch.nn.ReLU(),
 | 
				
			||||||
 | 
					            torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2),
 | 
				
			||||||
 | 
					            torch.nn.ReLU(),
 | 
				
			||||||
 | 
					            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Use GPU acceleration if available
 | 
				
			||||||
 | 
					        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Adam optimizer with learning rate 1e-3
 | 
				
			||||||
 | 
					        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Mean Squared Error loss function
 | 
				
			||||||
 | 
					        # self.loss_function = torch.nn.MSELoss()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.after_init()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										75
									
								
								models/cifar10_dataset.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								models/cifar10_dataset.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,75 @@
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy
 | 
				
			||||||
 | 
					import torchvision
 | 
				
			||||||
 | 
					from PIL import Image
 | 
				
			||||||
 | 
					from torchvision import transforms
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from config import DATASET_STORAGE_BASE_PATH
 | 
				
			||||||
 | 
					from models.base_dataset import BaseDataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Cifar10Dataset(BaseDataset):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
 | 
				
			||||||
 | 
					    #                                             torchvision.transforms.Normalize((0.5, ), (0.5, ))
 | 
				
			||||||
 | 
					    #                                             ])
 | 
				
			||||||
 | 
					    transform = transforms.Compose([
 | 
				
			||||||
 | 
					        transforms.ToPILImage(),
 | 
				
			||||||
 | 
					        transforms.ToTensor(),
 | 
				
			||||||
 | 
					        # transforms.Normalize((0.5,), (0.5,))
 | 
				
			||||||
 | 
					    ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def unpickle(self, filename):
 | 
				
			||||||
 | 
					        import pickle
 | 
				
			||||||
 | 
					        with open(filename, 'rb') as fo:
 | 
				
			||||||
 | 
					            dict = pickle.load(fo, encoding='bytes')
 | 
				
			||||||
 | 
					        return dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load(self, name: Optional[str] = None, path: Optional[str] = None):
 | 
				
			||||||
 | 
					        if name is not None:
 | 
				
			||||||
 | 
					            self.name = name
 | 
				
			||||||
 | 
					        if path is not None:
 | 
				
			||||||
 | 
					            self._source_path = path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._data = []
 | 
				
			||||||
 | 
					        for i in range(1, 6):
 | 
				
			||||||
 | 
					            data = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH,
 | 
				
			||||||
 | 
					                                              self._source_path,
 | 
				
			||||||
 | 
					                                              f"data_batch_{i}"))
 | 
				
			||||||
 | 
					            self._data.extend(data[b'data'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=self._data[:],
 | 
				
			||||||
 | 
					                                                source_path=self._source_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        test_data = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH,
 | 
				
			||||||
 | 
					                                               self._source_path,
 | 
				
			||||||
 | 
					                                               f"test_batch"))
 | 
				
			||||||
 | 
					        self._data.extend(test_data[b'data'])
 | 
				
			||||||
 | 
					        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data[b'data'][:],
 | 
				
			||||||
 | 
					                                               source_path=self._source_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getitem__(self, item):
 | 
				
			||||||
 | 
					        # Get image data
 | 
				
			||||||
 | 
					        img = self._data[item]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        img_r, img_g, img_b = img.reshape((3, 1024))
 | 
				
			||||||
 | 
					        img_r = img_r.reshape((32, 32))
 | 
				
			||||||
 | 
					        img_g = img_g.reshape((32, 32))
 | 
				
			||||||
 | 
					        img_b = img_b.reshape((32, 32))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Reshape to 32x32x3 image
 | 
				
			||||||
 | 
					        img = numpy.stack((img_r, img_g, img_b), axis=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Run transforms
 | 
				
			||||||
 | 
					        if self.transform is not None:
 | 
				
			||||||
 | 
					            img = self.transform(img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Reshape the 32x32x3 image to a 1x3072 array for the Linear layer
 | 
				
			||||||
 | 
					        img = img.view(-1, 32 * 32 * 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return img
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,10 @@
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					import multiprocessing
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from models.base_corruption import BaseCorruption
 | 
					from models.base_corruption import BaseCorruption
 | 
				
			||||||
from models.base_dataset import BaseDataset
 | 
					from models.base_dataset import BaseDataset
 | 
				
			||||||
from models.base_encoder import BaseEncoder
 | 
					from models.base_encoder import BaseEncoder
 | 
				
			||||||
 | 
					from output_utils import save_train_loss_graph
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestRun:
 | 
					class TestRun:
 | 
				
			||||||
| 
						 | 
					@ -12,18 +16,52 @@ class TestRun:
 | 
				
			||||||
        self.dataset = dataset
 | 
					        self.dataset = dataset
 | 
				
			||||||
        self.encoder = encoder
 | 
					        self.encoder = encoder
 | 
				
			||||||
        self.corruption = corruption
 | 
					        self.corruption = corruption
 | 
				
			||||||
 | 
					        self.log = logging.getLogger(self.__class__.__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self):
 | 
					    def run(self, retrain: bool = False, save_model: bool = True):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Run the test
 | 
				
			||||||
 | 
					        :param retrain: If the auto-encoder should be trained from scratch
 | 
				
			||||||
 | 
					        :type retrain: bool
 | 
				
			||||||
 | 
					        :param save_model: If the auto-encoder should be saved after re-training (only effective when retraining)
 | 
				
			||||||
 | 
					        :type save_model: bool
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        # Verify inputs
 | 
				
			||||||
        if self.dataset is None:
 | 
					        if self.dataset is None:
 | 
				
			||||||
            raise ValueError("Cannot run test! Dataset is not specified.")
 | 
					            raise ValueError("Cannot run test! Dataset is not specified.")
 | 
				
			||||||
        if self.encoder is None:
 | 
					        if self.encoder is None:
 | 
				
			||||||
            raise ValueError("Cannot run test! AutoEncoder is not specified.")
 | 
					            raise ValueError("Cannot run test! AutoEncoder is not specified.")
 | 
				
			||||||
        if self.corruption is None:
 | 
					        if self.corruption is None:
 | 
				
			||||||
            raise ValueError("Cannot run test! Corruption method is not specified.")
 | 
					            raise ValueError("Cannot run test! Corruption method is not specified.")
 | 
				
			||||||
        return self._run()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _run(self):
 | 
					        # Load dataset
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        self.log.info("Loading dataset...")
 | 
				
			||||||
 | 
					        self.dataset.load()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if retrain:
 | 
				
			||||||
 | 
					            # Train encoder
 | 
				
			||||||
 | 
					            self.log.info("Training auto-encoder...")
 | 
				
			||||||
 | 
					            train_loss = self.encoder.train_encoder(self.dataset, epochs=50, num_workers=multiprocessing.cpu_count() - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if save_model:
 | 
				
			||||||
 | 
					                self.log.info("Saving auto-encoder model...")
 | 
				
			||||||
 | 
					                self.encoder.save_model(f"{self.encoder.name}_{self.dataset.name}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Save train loss graph
 | 
				
			||||||
 | 
					            self.log.info("Saving loss graph...")
 | 
				
			||||||
 | 
					            save_train_loss_graph(train_loss, self.dataset.name)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.log.info("Loading saved auto-encoder...")
 | 
				
			||||||
 | 
					            load_success = self.encoder.load_model(f"{self.encoder.name}_{self.dataset.name}")
 | 
				
			||||||
 | 
					            if not load_success:
 | 
				
			||||||
 | 
					                self.log.error("Loading failed. Stopping test run.")
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Test encoder
 | 
				
			||||||
 | 
					        self.log.info("Testing auto-encoder...")
 | 
				
			||||||
 | 
					        self.encoder.test_encoder(self.dataset, num_workers=multiprocessing.cpu_count() - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.log.info("Done!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_metrics(self):
 | 
					    def get_metrics(self):
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        raise NotImplementedError()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										109
									
								
								output_utils.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								output_utils.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,109 @@
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					from string import Template
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from tabulate import tabulate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from config import TRAIN_TEMP_DATA_BASE_PATH
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def training_header():
 | 
				
			||||||
 | 
					    """Prints a training header to the console"""
 | 
				
			||||||
 | 
					    print('| Epoch | Avg. loss | Train Acc. | Test Acc.  | Elapsed  |   ETA    |\n'
 | 
				
			||||||
 | 
					          '|-------+-----------+------------+------------+----------+----------|')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def training_epoch_stats(epoch, running_loss, total, train_acc, test_acc, elapsed, eta):
 | 
				
			||||||
 | 
					    """This function is called every time an epoch ends, all of the important training statistics are taken as
 | 
				
			||||||
 | 
					    an input and outputted to the console.
 | 
				
			||||||
 | 
					   :param epoch: current training epoch epoch
 | 
				
			||||||
 | 
					   :type epoch: int
 | 
				
			||||||
 | 
					   :param running_loss: current total running loss of the model
 | 
				
			||||||
 | 
					   :type running_loss: float
 | 
				
			||||||
 | 
					   :param total: number of the images that model has trained on
 | 
				
			||||||
 | 
					   :type total: int
 | 
				
			||||||
 | 
					   :param train_acc: models accuracy on the provided train dataset
 | 
				
			||||||
 | 
					   :type train_acc: str
 | 
				
			||||||
 | 
					   :param test_acc: models accuracy on the optionally provided eval dataset
 | 
				
			||||||
 | 
					   :type test_acc: str
 | 
				
			||||||
 | 
					   :param elapsed: total time it took to run the epoch
 | 
				
			||||||
 | 
					   :type elapsed: str
 | 
				
			||||||
 | 
					   :param eta: an estimated time when training will end
 | 
				
			||||||
 | 
					   :type eta: str
 | 
				
			||||||
 | 
					   """
 | 
				
			||||||
 | 
					    print('| {:5d} | {:.7f} | {:10s} | {:10s} | {:8s} | {} |'
 | 
				
			||||||
 | 
					          .format(epoch, running_loss / total, train_acc, test_acc, elapsed, eta))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def training_finished(start, now):
 | 
				
			||||||
 | 
					    """Outputs elapsed training time to the console.
 | 
				
			||||||
 | 
					    :param start: time when the training has started
 | 
				
			||||||
 | 
					    :type start: datetime.datetime
 | 
				
			||||||
 | 
					    :param now: current time
 | 
				
			||||||
 | 
					    :type now: datetime.datetime
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    print('Training finished, total time elapsed: {}\n'.format(now - start))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def accuracy_summary_basic(total, correct, acc):
 | 
				
			||||||
 | 
					    """Outputs model evaluation statistics to the console if verbosity is set to 1.
 | 
				
			||||||
 | 
					    :param total: number of total objects model was evaluated on
 | 
				
			||||||
 | 
					    :type total: int
 | 
				
			||||||
 | 
					    :param correct: number of correctly classified objects
 | 
				
			||||||
 | 
					    :type correct: int
 | 
				
			||||||
 | 
					    :param acc: overall accuracy of the model
 | 
				
			||||||
 | 
					    :type acc: float
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    print('Total: {} -- Correct: {} -- Accuracy: {}%\n'.format(total, correct, acc))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def accuracy_summary_extended(classes, class_total, class_correct):
 | 
				
			||||||
 | 
					    """Outputs model evaluation statistics to the console if  verbosity is equal or greater than 2.
 | 
				
			||||||
 | 
					    :param classes: list with class names on which model was evaluated
 | 
				
			||||||
 | 
					    :type classes: list
 | 
				
			||||||
 | 
					    :param class_total: list with a number of classes on which model was evaluated
 | 
				
			||||||
 | 
					    :type class_total: list
 | 
				
			||||||
 | 
					    :param class_correct: list with a number of properly classified classes
 | 
				
			||||||
 | 
					    :type class_correct: list
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    print('| {} | {} | {} |'.format(type(classes), type(class_total), type(class_correct)))
 | 
				
			||||||
 | 
					    table = []
 | 
				
			||||||
 | 
					    for i, c in enumerate(classes):
 | 
				
			||||||
 | 
					        if class_total[i] != 0:
 | 
				
			||||||
 | 
					            class_acc = '{:.1f}%'.format(100 * class_correct[i] / class_total[i])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            class_acc = '-'
 | 
				
			||||||
 | 
					        table.append([c, class_total[i], class_correct[i], class_acc])
 | 
				
			||||||
 | 
					    print(tabulate(table, headers=['Class', 'Total', 'Correct', 'Acc'], tablefmt='orgtbl'), '\n')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def strfdelta(tdelta, fmt='%H:%M:%S'):
 | 
				
			||||||
 | 
					    """Similar to strftime, but this one is for a datetime.timedelta object.
 | 
				
			||||||
 | 
					    :param tdelta: datetime object containing some time difference
 | 
				
			||||||
 | 
					    :type tdelta: datetime.timedelta
 | 
				
			||||||
 | 
					    :param fmt: string with a format
 | 
				
			||||||
 | 
					    :type fmt: str
 | 
				
			||||||
 | 
					    :return: string containing a time in a given format
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class DeltaTemplate(Template):
 | 
				
			||||||
 | 
					        delimiter = "%"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    d = {"D": tdelta.days}
 | 
				
			||||||
 | 
					    hours, rem = divmod(tdelta.seconds, 3600)
 | 
				
			||||||
 | 
					    minutes, seconds = divmod(rem, 60)
 | 
				
			||||||
 | 
					    d["H"] = '{:02d}'.format(hours)
 | 
				
			||||||
 | 
					    d["M"] = '{:02d}'.format(minutes)
 | 
				
			||||||
 | 
					    d["S"] = '{:02d}'.format(seconds)
 | 
				
			||||||
 | 
					    t = DeltaTemplate(fmt)
 | 
				
			||||||
 | 
					    return t.substitute(**d)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def save_train_loss_graph(train_loss, filename):
 | 
				
			||||||
 | 
					    plt.figure()
 | 
				
			||||||
 | 
					    plt.plot(train_loss)
 | 
				
			||||||
 | 
					    plt.title('Train Loss')
 | 
				
			||||||
 | 
					    plt.xlabel('Epochs')
 | 
				
			||||||
 | 
					    plt.ylabel('Loss')
 | 
				
			||||||
 | 
					    plt.savefig(os.path.join(TRAIN_TEMP_DATA_BASE_PATH, f'{filename}_loss.png'))
 | 
				
			||||||
							
								
								
									
										5
									
								
								requirements.txt
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								requirements.txt
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,5 @@
 | 
				
			||||||
 | 
					torch==1.7.0+cu110
 | 
				
			||||||
 | 
					torchvision==0.8.1+cu110
 | 
				
			||||||
 | 
					torchaudio===0.7.0
 | 
				
			||||||
 | 
					tabulate
 | 
				
			||||||
 | 
					matplotlib
 | 
				
			||||||
		Reference in a new issue