# fast-mri/src/sfransen/DWI_exp/batchgenerator.py
# (repository-viewer metadata: 65 lines, 2.3 KiB, Python, executable file)

import math
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
from tensorflow.keras.utils import Sequence
class BatchGenerator(Sequence):
    """Keras ``Sequence`` yielding ``(images, segmentations)`` batches.

    All images are stacked once at construction time into a single array of
    shape ``(num_samples, *sample_shape, len(sequences))``: sample index on
    axis 0, one channel per entry of ``sequences`` on the last axis.
    Segmentations are stacked to ``(num_samples, *sample_shape, 1)``.

    Args:
        images: Mapping from sequence name to a list of per-sample arrays.
            Every list used must have the same length, and all arrays the
            same shape (required by ``np.stack``).
        segmentations: Per-sample label arrays, in the same sample order as
            the lists in ``images``.
        sequences: Keys of ``images`` to use, in output channel order.
        shape: Not read by this class; kept for interface compatibility.
        indexes: Sample indices to draw batches from. ``None`` selects every
            sample; an int ``n`` selects ``range(n)``. A copy is stored, so
            the per-epoch shuffle never mutates the caller's list.
        batch_size: Samples per batch. ``None`` puts every index in a single
            batch.
        shuffle: If true, ``indexes`` is shuffled at the end of every epoch.
            The first epoch uses the order given.
        augmentation_function: Optional per-sample callable
            ``f(image, segmentation) -> (image, segmentation)``.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if the number of
            stacked images differs from the number of segmentations.
    """

    def __init__(self,
                 images: Dict[str, List[np.ndarray]],
                 segmentations: List[np.ndarray],
                 sequences: List[str],
                 shape: Tuple[int, ...],
                 indexes: Optional[List[int]] = None,
                 batch_size: Optional[int] = 5,
                 shuffle: bool = False,
                 augmentation_function: Optional[Callable[
                     [np.ndarray, np.ndarray],
                     Tuple[np.ndarray, np.ndarray]]] = None):
        # Keras 3's Sequence/PyDataset expects its initializer to run;
        # harmless for the TF2 Keras Sequence.
        super().__init__()
        self.images = images
        self.segmentations = segmentations
        self.sequences = sequences
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmentation = augmentation_function
        # One row per sample. (`len(images)` would count sequence names.)
        self.num_rows = len(segmentations)

        if indexes is None:
            self.indexes = list(range(self.num_rows))
        elif isinstance(indexes, (int, np.integer)):
            self.indexes = list(range(indexes))
        else:
            # Copy: on_epoch_end shuffles in place.
            self.indexes = list(indexes)

        if self.batch_size is None:
            # Single batch with every index; at least 1 so __len__ never
            # divides by zero when indexes is empty.
            self.batch_size = max(len(self.indexes), 1)
        if self.batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {self.batch_size}")

        # Stack image index as axis 0, stack sequences as axis -1.
        self._images = np.stack(
            [np.stack(self.images[k], axis=0) for k in sequences], axis=-1)
        self._segmentations = np.stack(segmentations, axis=0)[..., np.newaxis]
        if self._images.shape[0] != self._segmentations.shape[0]:
            raise ValueError(
                f"Got {self._images.shape[0]} images but "
                f"{self._segmentations.shape[0]} segmentations")

    def __len__(self):
        """Return the number of batches per epoch (last batch may be short)."""
        return math.ceil(len(self.indexes) / self.batch_size)

    def __getitem__(self, batch_idx):
        """Return the ``(images, segmentations)`` arrays for batch ``batch_idx``."""
        start = batch_idx * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        # Indexing with a list is NumPy advanced indexing, which returns a
        # copy, so in-place augmentation below never alters the stored data.
        images = self._images[batch_indexes]
        segmentations = self._segmentations[batch_indexes]
        if self.augmentation:
            for i in range(len(batch_indexes)):
                images[i], segmentations[i] = self.augmentation(
                    images[i], segmentations[i])
        return images, segmentations

    def on_epoch_end(self):
        """Run the base-class hook, then reshuffle indexes if enabled."""
        super().on_epoch_end()
        if self.shuffle:
            np.random.shuffle(self.indexes)