import math
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
from tensorflow.keras.utils import Sequence


class BatchGenerator(Sequence):
    """Keras ``Sequence`` yielding (images, segmentations) batches.

    All images are stacked once at construction time into a single array
    whose first axis is the image index and whose last axis is the sequence
    (one entry per name in ``sequences``).  The segmentations are stacked
    along the first axis and receive a trailing singleton axis.
    """

    def __init__(self,
                 images: Dict[str, List[np.ndarray]],
                 segmentations: List[np.ndarray],
                 sequences: List[str],
                 shape: Tuple[int, ...],
                 indexes: Optional[Union[List[int], int]] = None,
                 batch_size: Optional[int] = 5,
                 shuffle: bool = False,
                 augmentation_function: Optional[Callable[
                     [np.ndarray, np.ndarray],
                     Tuple[np.ndarray, np.ndarray]]] = None):
        """Build the generator and pre-stack the data.

        Args:
            images: Mapping from sequence name to the list of images of that
                sequence.  Every sequence listed in ``sequences`` must be
                present and hold equally shaped images.
            segmentations: One segmentation array per image.
            sequences: Names (keys of ``images``) to stack, in this order,
                along the last axis.
            shape: Unused by this class; kept so existing callers keep
                working.
            indexes: Image indexes served by this generator.  ``None`` means
                every image; an integer ``n`` means ``range(n)``.
            batch_size: Number of samples per batch.  ``None`` means a single
                batch holding all of ``indexes``.
            shuffle: Whether to reshuffle ``indexes`` after each epoch.
            augmentation_function: Optional callable applied per sample as
                ``(image, segmentation) -> (image, segmentation)``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        super().__init__()

        self.images = images
        self.segmentations = segmentations
        self.sequences = sequences
        self.shuffle = shuffle
        self.augmentation = augmentation_function

        # Stack image index as index 0, stack sequences as index -1
        self._images = np.stack(
            [np.stack(images[k], axis=0) for k in sequences], axis=-1)
        self._segmentations = np.stack(
            segmentations, axis=0)[..., np.newaxis]

        # Number of samples is the length of the image axis, not the number
        # of keys in the ``images`` dict (which is the number of sequences).
        self.num_rows = self._images.shape[0]

        if indexes is None:
            self.indexes = list(range(self.num_rows))
        elif isinstance(indexes, (int, np.integer)):
            self.indexes = list(range(indexes))
        else:
            # Copy: on_epoch_end shuffles in place and must not reorder the
            # caller's own list.
            self.indexes = list(indexes)

        if batch_size is None:
            # One batch with everything; never 0 so __len__ cannot divide
            # by zero when there are no indexes.
            batch_size = max(len(self.indexes), 1)
        if batch_size < 1:
            raise ValueError(
                "batch_size must be a positive integer, got "
                "{!r}".format(batch_size))
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches per epoch (last one may be partial)."""
        return math.ceil(len(self.indexes) / self.batch_size)

    def __getitem__(self, batch_idx):
        """Return the ``(images, segmentations)`` arrays of one batch.

        Args:
            batch_idx: Position of the batch within the epoch.

        Returns:
            Tuple of two arrays whose first axis runs over the samples of
            the batch.
        """
        # Get image / segmentation indexes for current batch
        start = batch_idx * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]

        # Indexing with a list yields copies, so the in-place augmentation
        # below never alters the stored arrays.
        images = self._images[batch_indexes]
        segmentations = self._segmentations[batch_indexes]

        # Apply data augmentation if a function is provided
        if self.augmentation is not None:
            for pos, (image, segmentation) in enumerate(
                    zip(images, segmentations)):
                images[pos], segmentations[pos] = self.augmentation(
                    image, segmentation)

        return images, segmentations

    def on_epoch_end(self):
        """Reshuffle the sample order after an epoch when shuffling is on."""
        super().on_epoch_end()
        if self.shuffle:
            np.random.shuffle(self.indexes)