First basic auto-encoder and CIFAR-10 dataset implemented

parent 51cc5d1d30
commit 62f9b873e9

10 changed files with 509 additions and 27 deletions
models/base_dataset.py

@@ -1,8 +1,11 @@
-import math
+import math, logging
 from typing import Union, Optional
 
 import torch
+from torch.utils.data import Dataset
 
-class BaseDataset:
+class BaseDataset(Dataset):
 
     # Train amount is either a proportion of the data that should be used as training data (between 0 and 1),
     # or an integer indicating how many entries should be used as training data (e.g. 1000, 2000)
@@ -16,10 +19,14 @@ class BaseDataset:
     _data = None
     _trainset: 'BaseDataset' = None
     _testset: 'BaseDataset' = None
+    transform = None
 
-    def __init__(self, name: Optional[str] = None):
+    def __init__(self, name: Optional[str] = None, path: Optional[str] = None):
+        self.log = logging.getLogger(self.__class__.__name__)
         if name is not None:
             self.name = name
+        if path is not None:
+            self._source_path = path
 
     def __str__(self):
         if self._data is not None:
@@ -27,19 +34,30 @@ class BaseDataset:
         else:
             return f"{self.name} (no data loaded)"
 
+    # __len__ so that len(dataset) returns the size of the dataset.
+    # __getitem__ to support indexing, so that dataset[i] can be used to get the i-th sample
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, item):
+        return self.transform(self._data[item]) if self.transform else self._data[item]
+
     @classmethod
     def get_new(cls, name: str, data: Optional[list] = None, source_path: Optional[str] = None,
                 train_set: Optional['BaseDataset'] = None, test_set: Optional['BaseDataset'] = None):
         dset = cls()
         dset.name = name
         dset._data = data
         dset._source_path = source_path
         dset._trainset = train_set
         dset._testset = test_set
         return dset
 
-    def load(self, name: str, path: str):
-        self.name = str
-        self._source_path = path
+    def load(self, name: Optional[str] = None, path: Optional[str] = None):
+        if name is not None:
+            self.name = name
+        if path is not None:
+            self._source_path = path
+        raise NotImplementedError()
 
     def _subdivide(self, amount: Union[int, float]):
@@ -69,3 +87,13 @@ class BaseDataset:
         if not self._trainset or not self._testset:
             self._subdivide(self.TRAIN_AMOUNT)
         return self._testset
+
+    def get_loader(self, dataset, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
+        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
+                                           num_workers=num_workers, pin_memory=True)
+
+    def get_train_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
+        return self.get_loader(self.get_train(), batch_size=batch_size, num_workers=num_workers)
+
+    def get_test_loader(self, batch_size: int = 128, num_workers: int = 4) -> torch.utils.data.DataLoader:
+        return self.get_loader(self.get_test(), batch_size=batch_size, num_workers=num_workers)
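Note on the new loader API: get_train_loader/get_test_loader are meant to be called on a concrete subclass that fills _data in load(). A minimal sketch of the intended call pattern, assuming _subdivide and get_train behave as in the unchanged part of the file; the ToyDataset class and its random data are hypothetical, only the loader methods come from this commit:

import torch
from models.base_dataset import BaseDataset

class ToyDataset(BaseDataset):  # hypothetical subclass, for illustration only
    TRAIN_AMOUNT = 0.8  # 80% of entries become training data

    def load(self, name=None, path=None):
        # Ten random 3072-dim vectors standing in for flattened 32x32x3 images
        self._data = [torch.rand(32 * 32 * 3) for _ in range(10)]

ds = ToyDataset(name="toy")
ds.load()
for batch in ds.get_train_loader(batch_size=4, num_workers=0):
    print(batch.shape)  # torch.Size([4, 3072])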
models/base_encoder.py

@@ -1,20 +1,172 @@
+import json
 import logging
+import os
+
+import torch
 
 from typing import Optional
 
+from torchvision.utils import save_image
+
+from config import TRAIN_TEMP_DATA_BASE_PATH, TEST_TEMP_DATA_BASE_PATH, MODEL_STORAGE_BASE_PATH
 from models.base_dataset import BaseDataset
 
 
-class BaseEncoder:
+class BaseEncoder(torch.nn.Module):
+    # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
     name = "BaseEncoder"
 
-    def __init__(self, name: Optional[str] = None):
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+        super(BaseEncoder, self).__init__()
         self.log = logging.getLogger(self.__class__.__name__)
 
         if name is not None:
             self.name = name
 
+        assert input_shape != 0, f"Encoder {self.__class__.__name__} input_shape parameter should not be 0"
+
+        self.input_shape = input_shape
+
+        # Default fallbacks (can be overridden by sub implementations)
+
+        # 4 layer NN, halving the input each layer
+        self.network = torch.nn.Sequential(
+            torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape)
+        )
+
+        # Use GPU acceleration if available
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Adam optimizer with learning rate 1e-3
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+        # Mean Squared Error loss function
+        self.loss_function = torch.nn.MSELoss()
+
+    def after_init(self):
+        self.log.info(f"Auto-encoder {self.__class__.__name__} initialized with "
+                      f"{len(list(self.network.children()))} layers on {self.device.type}. "
+                      f"Optimizer: {self.optimizer}, Loss function: {self.loss_function}")
+
+    def forward(self, features):
+        return self.network(features)
+
+    def save_model(self, filename):
+        torch.save(self.state_dict(), os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.model"))
+        with open(os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.meta"), 'w') as f:
+            f.write(json.dumps({
+                'name': self.name,
+                'input_shape': self.input_shape
+            }))
+
+    def load_model(self, filename=None):
+        if filename is None:
+            filename = f"{self.name}"
+        try:
+            loaded_model = torch.load(os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.model"), map_location=self.device)
+            self.load_state_dict(loaded_model)
+            self.to(self.device)
+            return True
+        except OSError as e:
+            self.log.error(f"Could not load model '{filename}': {e}")
+            return False
+
+    @classmethod
+    def create_model_from_file(cls, filename, device=None):
+        try:
+            if device is None:
+                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            with open(os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.meta")) as f:
+                model_kwargs = json.loads(f.read())
+            model = cls(**model_kwargs)
+            loaded_model = torch.load(os.path.join(MODEL_STORAGE_BASE_PATH, f"{filename}.model"), map_location=device)
+            model.load_state_dict(loaded_model)
+            model.to(device)
+            return model
+        except OSError as e:
+            log = logging.getLogger(cls.__name__)
+            log.error(f"Could not load model '{filename}': {e}")
+            return None
 
     def __str__(self):
         return f"{self.name}"
 
-    def train(self, dataset: BaseDataset):
-        raise NotImplementedError()
+    def train_encoder(self, dataset: BaseDataset, epochs: int = 20, batch_size: int = 128, num_workers: int = 4):
+        self.log.debug("Getting training dataset DataLoader.")
+        train_loader = dataset.get_train_loader(batch_size=batch_size, num_workers=num_workers)
 
-    def test(self, dataset: BaseDataset):
-        raise NotImplementedError()
+        # Put the module in training mode.
+        self.log.debug("Putting model into training mode.")
+        self.train()
+        self.to(self.device, non_blocking=True)
+        self.loss_function.to(self.device, non_blocking=True)
+
+        losses = []
+
+        outputs = None
+        for epoch in range(epochs):
+            self.log.debug(f"Start training epoch {epoch}...")
+            loss = 0
+            for batch_features in train_loader:
+                # load batch features to the active device
+                batch_features = batch_features.to(self.device)
+
+                # reset the gradients back to zero;
+                # PyTorch accumulates gradients on subsequent backward passes
+                self.optimizer.zero_grad()
+
+                # compute reconstructions
+                outputs = self(batch_features)
+
+                # compute training reconstruction loss
+                train_loss = self.loss_function(outputs, batch_features)
+
+                # compute accumulated gradients
+                train_loss.backward()
+
+                # perform parameter update based on current gradients
+                self.optimizer.step()
+
+                # add the mini-batch training loss to epoch loss
+                loss += train_loss.item()
+
+            # compute the epoch training loss
+            loss = loss / len(train_loader)
+
+            # display the epoch training loss
+            self.log.info("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
+            losses.append(loss)
+
+            # Every 5 epochs, save a test image
+            if epoch % 5 == 0:
+                img = outputs.cpu().data
+                img = img.view(img.size(0), 3, 32, 32)
+                save_image(img, os.path.join(TRAIN_TEMP_DATA_BASE_PATH,
+                                             f'{self.name}_{dataset.name}_linear_ae_image{epoch}.png'))
+
+        return losses
+
+    def test_encoder(self, dataset: BaseDataset, batch_size: int = 128, num_workers: int = 4):
+        self.log.debug("Getting testing dataset DataLoader.")
+        test_loader = dataset.get_test_loader(batch_size=batch_size, num_workers=num_workers)
+
+        self.log.debug("Start testing...")
+        i = 0
+        for batch in test_loader:
+            img = batch.view(batch.size(0), 3, 32, 32)
+            save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
+                                         f'{self.name}_{dataset.name}_test_input_{i}.png'))
+            # load batch features to the active device
+            batch = batch.to(self.device)
+            outputs = self(batch)
+            img = outputs.cpu().data
+            img = img.view(outputs.size(0), 3, 32, 32)
+            save_image(img, os.path.join(TEST_TEMP_DATA_BASE_PATH,
+                                         f'{self.name}_{dataset.name}_test_reconstruction_{i}.png'))
+            i += 1
+            break
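The methods above compose into a train -> save -> reload cycle. A hedged sketch of that sequence, assuming `dataset` is an already-loaded BaseDataset yielding flat 3072-dim tensors; the file name is illustrative, and BaseEncoder is usable directly here only because it now builds a default network:

# Sketch only: `dataset` and the model file name are assumptions.
encoder = BaseEncoder(input_shape=32 * 32 * 3)

losses = encoder.train_encoder(dataset, epochs=20, batch_size=128, num_workers=0)
encoder.save_model("BaseEncoder_demo")  # writes .model weights plus .meta constructor kwargs

# Rebuild later: .meta supplies the constructor kwargs, .model the weights.
restored = BaseEncoder.create_model_from_file("BaseEncoder_demo")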
models/basic_encoder.py (new file, 48 lines)

@@ -0,0 +1,48 @@
+import json
+import logging
+import os
+
+import torch
+
+from typing import Optional
+
+from models.base_encoder import BaseEncoder
+
+
+class BasicAutoEncoder(BaseEncoder):
+    # Based on https://medium.com/pytorch/implementing-an-autoencoder-in-pytorch-19baa22647d1
+    name = "BasicAutoEncoder"
+
+    def __init__(self, name: Optional[str] = None, input_shape: int = 0):
+        self.log = logging.getLogger(self.__class__.__name__)
+
+        # Call superclass to initialize parameters.
+        super(BasicAutoEncoder, self).__init__(name, input_shape)
+
+        # Override parameters to custom values for this encoder type
+
+        # TODO - How can I best determine how large the intermediate layers should be?
+        # - Proportionally, from the input size down to a specified bottleneck size?
+        # - Take the sizes from a paper
+        # - Choose them myself (e.g. halve each time, fixed number of layers)?
+        self.network = torch.nn.Sequential(
+            torch.nn.Linear(in_features=input_shape, out_features=input_shape // 2),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape // 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape // 4, out_features=input_shape // 2),
+            torch.nn.ReLU(),
+            torch.nn.Linear(in_features=input_shape // 2, out_features=input_shape)
+        )
+
+        # Use GPU acceleration if available
+        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Adam optimizer with learning rate 1e-3
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+        # Mean Squared Error loss function
+        # self.loss_function = torch.nn.MSELoss()
+
+        self.after_init()
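On the layer-sizing TODO above: one common heuristic is to interpolate the layer widths geometrically between the input size and a chosen bottleneck, so each layer shrinks by a constant factor. A sketch of that idea; the hidden_sizes helper is hypothetical and not part of this commit:

# Hypothetical helper, not in this commit: geometric interpolation of encoder widths.
def hidden_sizes(input_shape: int, bottleneck: int, n_layers: int) -> list:
    ratio = (bottleneck / input_shape) ** (1 / n_layers)
    return [max(bottleneck, round(input_shape * ratio ** i)) for i in range(1, n_layers + 1)]

print(hidden_sizes(32 * 32 * 3, 48, 3))  # [768, 192, 48]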
models/cifar10_dataset.py (new file, 75 lines)

@@ -0,0 +1,75 @@
+import os
+
+from typing import Optional
+
+import numpy
+import torchvision
+from PIL import Image
+from torchvision import transforms
+
+from config import DATASET_STORAGE_BASE_PATH
+from models.base_dataset import BaseDataset
+
+
+class Cifar10Dataset(BaseDataset):
+
+    # transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
+    #                                             torchvision.transforms.Normalize((0.5, ), (0.5, ))
+    #                                             ])
+    transform = transforms.Compose([
+        transforms.ToPILImage(),
+        transforms.ToTensor(),
+        # transforms.Normalize((0.5,), (0.5,))
+    ])
+
+    def unpickle(self, filename):
+        import pickle
+        with open(filename, 'rb') as fo:
+            dict = pickle.load(fo, encoding='bytes')
+        return dict
+
+    def load(self, name: Optional[str] = None, path: Optional[str] = None):
+        if name is not None:
+            self.name = name
+        if path is not None:
+            self._source_path = path
+
+        self._data = []
+        for i in range(1, 6):
+            data = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH,
+                                              self._source_path,
+                                              f"data_batch_{i}"))
+            self._data.extend(data[b'data'])
+
+        self._trainset = self.__class__.get_new(name=f"{self.name} Training", data=self._data[:],
+                                                source_path=self._source_path)
+
+        test_data = self.unpickle(os.path.join(DATASET_STORAGE_BASE_PATH,
+                                               self._source_path,
+                                               "test_batch"))
+        self._data.extend(test_data[b'data'])
+        self._testset = self.__class__.get_new(name=f"{self.name} Testing", data=test_data[b'data'][:],
+                                               source_path=self._source_path)
+
+        self.log.info(f"Loaded {self}, divided into {self._trainset} and {self._testset}")
+
+    def __getitem__(self, item):
+        # Get image data
+        img = self._data[item]
+
+        img_r, img_g, img_b = img.reshape((3, 1024))
+        img_r = img_r.reshape((32, 32))
+        img_g = img_g.reshape((32, 32))
+        img_b = img_b.reshape((32, 32))
+
+        # Reshape to a 32x32x3 image
+        img = numpy.stack((img_r, img_g, img_b), axis=2)
+
+        # Run transforms
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Reshape the 32x32x3 image to a 1x3072 array for the Linear layer
+        img = img.view(-1, 32 * 32 * 3)
+
+        return img
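For reference, the per-channel reshaping in __getitem__ follows the CIFAR-10 python-batch layout: each data row is 3072 uint8 values, the first 1024 red, the next 1024 green, the last 1024 blue, each channel stored row-major as 32x32. The stack-on-axis-2 above is equivalent to a single reshape/transpose, as this small check illustrates:

import numpy

row = numpy.arange(3072, dtype=numpy.uint8)      # stand-in for one CIFAR-10 data row
hwc = row.reshape(3, 32, 32).transpose(1, 2, 0)  # CHW -> HWC, shape (32, 32, 3)

# Same result as the per-channel reshape-and-stack in the diff above
r, g, b = row.reshape(3, 1024)
stacked = numpy.stack((r.reshape(32, 32), g.reshape(32, 32), b.reshape(32, 32)), axis=2)
assert (hwc == stacked).all()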
				
			
@@ -1,6 +1,10 @@
+import logging
+import multiprocessing
+
 from models.base_corruption import BaseCorruption
 from models.base_dataset import BaseDataset
 from models.base_encoder import BaseEncoder
+from output_utils import save_train_loss_graph
 
 
 class TestRun:
@@ -12,18 +16,52 @@ class TestRun:
         self.dataset = dataset
         self.encoder = encoder
         self.corruption = corruption
+        self.log = logging.getLogger(self.__class__.__name__)
 
-    def run(self):
+    def run(self, retrain: bool = False, save_model: bool = True):
+        """
+        Run the test.
+        :param retrain: Whether the auto-encoder should be trained from scratch
+        :type retrain: bool
+        :param save_model: Whether the auto-encoder should be saved after retraining (only effective when retraining)
+        :type save_model: bool
+        """
         # Verify inputs
         if self.dataset is None:
             raise ValueError("Cannot run test! Dataset is not specified.")
         if self.encoder is None:
             raise ValueError("Cannot run test! AutoEncoder is not specified.")
         if self.corruption is None:
             raise ValueError("Cannot run test! Corruption method is not specified.")
-        return self._run()
 
-    def _run(self):
-        raise NotImplementedError()
+        # Load dataset
+        self.log.info("Loading dataset...")
+        self.dataset.load()
+
+        if retrain:
+            # Train encoder
+            self.log.info("Training auto-encoder...")
+            train_loss = self.encoder.train_encoder(self.dataset, epochs=50, num_workers=multiprocessing.cpu_count() - 1)
+
+            if save_model:
+                self.log.info("Saving auto-encoder model...")
+                self.encoder.save_model(f"{self.encoder.name}_{self.dataset.name}")
+
+            # Save train loss graph
+            self.log.info("Saving loss graph...")
+            save_train_loss_graph(train_loss, self.dataset.name)
+        else:
+            self.log.info("Loading saved auto-encoder...")
+            load_success = self.encoder.load_model(f"{self.encoder.name}_{self.dataset.name}")
+            if not load_success:
+                self.log.error("Loading failed. Stopping test run.")
+                return
+
+        # Test encoder
+        self.log.info("Testing auto-encoder...")
+        self.encoder.test_encoder(self.dataset, num_workers=multiprocessing.cpu_count() - 1)
+
+        self.log.info("Done!")
 
     def get_metrics(self):
         raise NotImplementedError()
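End-to-end, a run appears intended to look like the sketch below. Hedged: TestRun's constructor is not shown in this diff, so the keyword arguments are inferred from the attributes it sets, the dataset path is illustrative, and the corruption object is a placeholder until a concrete BaseCorruption implementation lands:

# Hedged usage sketch; constructor kwargs and the dataset path are assumptions.
dataset = Cifar10Dataset(name="CIFAR-10", path="cifar-10-batches-py")
encoder = BasicAutoEncoder(input_shape=32 * 32 * 3)
corruption = BaseCorruption()  # placeholder until a concrete corruption exists

run = TestRun(dataset=dataset, encoder=encoder, corruption=corruption)
run.run(retrain=True, save_model=True)  # trains 50 epochs, saves model + loss graph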