65 lines
2.3 KiB
Python
Executable File
65 lines
2.3 KiB
Python
Executable File
import math
|
|
from typing import Callable, Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
from tensorflow.keras.utils import Sequence
|
|
|
|
class BatchGenerator(Sequence):
    """Keras ``Sequence`` that serves batches of images and segmentations.

    At construction time all images are stacked into one array laid out as
    ``(image, *image_shape, sequence)`` and all segmentations into one array
    laid out as ``(image, *segmentation_shape, 1)``. A batch is obtained by
    selecting entries of ``self.indexes`` along the first axis of both
    arrays.
    """

    def __init__(self,
                 images: Dict[str, List[np.ndarray]],
                 segmentations: List[np.ndarray],
                 sequences: List[str],
                 shape: Tuple[int, ...],
                 indexes: Optional[List[int]] = None,
                 batch_size: Optional[int] = 5,
                 shuffle: bool = False,
                 augmentation_function: Optional[Callable[
                     [np.ndarray, np.ndarray],
                     Tuple[np.ndarray, np.ndarray]]] = None):
        """Stack the input data and set up the batch indexes.

        Args:
            images: Maps a sequence name to its list of image arrays. Every
                list named in ``sequences`` must hold equally shaped arrays
                so they can be stacked.
            segmentations: One segmentation array per image.
            sequences: Keys of ``images`` to use, in the order in which they
                are stacked along the last axis.
            shape: Accepted for interface compatibility; not read by this
                class.
            indexes: Image indexes to draw batches from. ``None`` selects
                every image; an integer ``n`` selects ``range(n)``; any
                other iterable of integers is copied into a new list.
            batch_size: Number of images per batch. ``None`` puts all
                selected images into a single batch.
            shuffle: Whether to shuffle ``indexes`` at the end of each epoch.
            augmentation_function: Optional callable that takes one image and
                its segmentation and returns the augmented pair. The results
                are written back into the batch arrays, so they must be
                assignable to slots of the original shapes.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        super().__init__()

        self.images = images
        self.segmentations = segmentations
        self.sequences = sequences
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmentation = augmentation_function

        # Stack image index as axis 0, stack sequences as axis -1.
        self._images = np.stack(
            [np.stack(images[name], axis=0) for name in sequences], axis=-1)
        self._segmentations = np.stack(segmentations, axis=0)[..., np.newaxis]

        # The number of rows is the number of images (axis 0 of the stacked
        # array), not the number of keys in the ``images`` dict.
        self.num_rows = self._images.shape[0]

        if indexes is None:
            self.indexes = list(range(self.num_rows))
        elif isinstance(indexes, (int, np.integer)):
            self.indexes = list(range(int(indexes)))
        else:
            # Copy so that shuffling in on_epoch_end never reorders the
            # caller's own list.
            self.indexes = [int(i) for i in indexes]

        if self.batch_size is None:
            # Keep the batch size positive even when nothing is selected so
            # that __len__ cannot divide by zero.
            self.batch_size = max(len(self.indexes), 1)
        if self.batch_size < 1:
            raise ValueError(
                f"batch_size must be at least 1, got {self.batch_size}")

    def __len__(self):
        """Return the number of batches per epoch (last batch may be short)."""
        return math.ceil(len(self.indexes) / self.batch_size)

    def __getitem__(self, batch_idx):
        """Return the ``(images, segmentations)`` pair for batch ``batch_idx``.

        Both returned arrays are copies of the stacked data, so augmentation
        never modifies the arrays held by the generator.
        """
        # Get image / segmentation indexes for the current batch.
        start = batch_idx * self.batch_size
        stop = start + self.batch_size
        batch_ids = np.asarray(self.indexes[start:stop], dtype=np.intp)

        # Indexing with an integer array always yields new arrays.
        images = self._images[batch_ids]
        segmentations = self._segmentations[batch_ids]

        # Apply data augmentation if a function is provided.
        if self.augmentation is not None:
            for i, (image, segmentation) in enumerate(
                    zip(images, segmentations)):
                images[i], segmentations[i] = self.augmentation(
                    image, segmentation)

        return images, segmentations

    def on_epoch_end(self):
        """Shuffle the batch indexes in place when shuffling is enabled."""
        super().on_epoch_end()

        if self.shuffle:
            np.random.shuffle(self.indexes)