# fast-mri/src/sfransen/DWI_exp/helpers.py
import os
from os import path
import numpy as np
import SimpleITK as sitk
from scipy import ndimage
import yaml
def print_(*args, **kwargs):
    """Print that always flushes, so output shows up immediately in logs."""
    print(*args, flush=True, **kwargs)
def print_config(config, depth=0):
    """Recursively pretty-print a (possibly nested) configuration dict.

    Parameters:
        config: Dictionary of configuration keys to values; nested dicts
            are rendered as indented sub-sections.
        depth: Current nesting level; indentation is 2 spaces per level.
    """
    for key in config:
        print(end=(" " * depth * 2) + f"- {key}:")
        value = config[key]
        # isinstance (instead of type(...) == dict) also recurses into
        # dict subclasses such as OrderedDict.
        if isinstance(value, dict):
            print()
            print_config(value, depth + 1)
        else:
            print(f"\t{value}")
    # Empty flushing print to make sure everything above is written out.
    print_(end="")
def prepare_project_dir(project_dir):
    """
    Prepares the work directory for a project
    TODO: Create text file containing configuration
    """
    for sub in ("models", "logs", "data", "docs", "samples", "output"):
        os.makedirs(path.join(project_dir, sub), exist_ok=True)
def fake_loss(*args, **kwargs):
    """Placeholder loss: ignores every argument and always returns 0.0."""
    return 0.0
def center_crop(
    image: sitk.Image, shape: list, offset: list = None
) -> sitk.Image:
    """Extracts a region of interest from the center of an SITK image.

    Parameters:
        image: Input image (SITK).
        shape: Per-dimension size of the region to extract.
        offset: Optional per-dimension amount subtracted from the computed
            starting index, shifting the crop box away from the center.

    Returns:
        Cropped image (SITK).
    """
    size = image.GetSize()
    # Determine the centroid
    centroid = [sz / 2 for sz in size]
    # Determine the origin for the bounding box by subtracting half the
    # shape of the bounding box from each dimension of the centroid.
    box_start = [int(c - sh / 2) for c, sh in zip(centroid, shape)]
    if offset:
        box_start = [b - o for b, o in zip(box_start, offset)]
    # Extract the region of provided shape starting from the previously
    # determined starting pixel.
    region_extractor = sitk.RegionOfInterestImageFilter()
    region_extractor.SetSize(shape)
    region_extractor.SetIndex(box_start)
    cropped_image = region_extractor.Execute(image)
    return cropped_image
def augment(img, seg,
            noise_chance=0.3,
            noise_mult_max=0.001,  # noise_mult_max was initially 0.1
            rotate_chance=0.2,
            rotate_max_angle=30,
            ):
    """Randomly add noise and/or rotate+tilt an image/segmentation pair."""
    draw = np.random.uniform
    if draw() < noise_chance:
        img = augment_noise(img, draw(0., noise_mult_max))
    if draw() < rotate_chance:
        img, seg = augment_rotate(
            img, seg, angle=draw(-rotate_max_angle, rotate_max_angle))
        img, seg = augment_tilt(
            img, seg, angle=draw(-rotate_max_angle, rotate_max_angle))
    return img, seg
def augment_noise(img, multiplier):
    """Return `img` plus zero-mean Gaussian noise scaled by `multiplier`."""
    return img + multiplier * np.random.standard_normal(img.shape)
def augment_rotate(img, seg, angle):
    """In-plane rotation: spline interp + zero fill for the image,
    nearest-neighbour (order=0) for the segmentation to keep labels discrete."""
    rotated_img = ndimage.rotate(img, angle, reshape=False, cval=0)
    rotated_seg = ndimage.rotate(seg, angle, reshape=False, order=0)
    return rotated_img, rotated_seg
def augment_tilt(img, seg, angle):
    """Out-of-plane rotation around axes (2, 1); segmentation uses
    nearest-neighbour (order=0) so label values stay discrete."""
    tilted_img = ndimage.rotate(img, angle, reshape=False, axes=(2, 1), cval=0)
    tilted_seg = ndimage.rotate(seg, angle, reshape=False, axes=(2, 1), order=0)
    return tilted_img, tilted_seg
def resample(
    image: sitk.Image, min_shape: list, method=sitk.sitkLinear,
    new_spacing: list = [1, 1, 3.6]
) -> sitk.Image:
    """Resamples an image to given target spacing and minimum shape.

    Parameters:
        image: Input image (SITK).
        min_shape: Minimum output shape for the underlying array.
        method: SimpleITK interpolator to use for resampling.
            (e.g. sitk.sitkNearestNeighbor, sitk.sitkLinear)
        new_spacing: The new spacing to resample to.
            NOTE: the list default is only read, never mutated, so the
            shared-mutable-default pitfall does not apply here.

    Returns:
        sitk.Image: Resampled image.
    """
    # Extract size and spacing from the image
    size = image.GetSize()
    spacing = image.GetSpacing()
    # Determine how much larger the image will become with the new spacing
    factor = [sp / new_sp for sp, new_sp in zip(spacing, new_spacing)]
    # Output size per dimension: scaled size, clamped below by min_shape.
    new_size = [
        max(int(sz * f), min_sz)
        for sz, f, min_sz in zip(size, factor, min_shape)
    ]
    # Resample the image
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(image)
    resampler.SetOutputSpacing(new_spacing)
    resampler.SetSize(new_size)
    resampler.SetInterpolator(method)
    resampled_image = resampler.Execute(image)
    return resampled_image
# Start training
def get_generator(
    images: dict,
    segmentations: list,
    sequences: list,
    shape: tuple,
    indexes: list = None,
    batch_size: int = 5,
    shuffle: bool = False,
    augmentation=True
):
    """
    Returns a (training) generator for use with model.fit().

    Parameters:
        images: Mapping of sequence name -> indexable collection of images.
        segmentations: Indexable collection of segmentation arrays.
        sequences: Sequence names (keys of `images`) stacked as input channels.
        shape: Spatial shape of a single observation.
        indexes: Only use the specified image indexes; an int n is treated
            as range(n); None uses all rows.
        batch_size: Number of images per batch (None: all indexes).
        shuffle: Reshuffle the index order every time the list wraps around.
        augmentation: Apply augmentation or not (bool).
    """
    # NOTE(review): this counts the number of sequences, not the number of
    # samples — preserved as-is; callers typically pass `indexes` explicitly.
    num_rows = len(images)
    # Normalize `indexes` to a concrete list of row indices.
    if indexes is None:
        indexes = list(range(num_rows))
    if isinstance(indexes, int):
        indexes = list(range(indexes))
    if batch_size is None:
        batch_size = len(indexes)
    idx = 0
    # Pre-allocated batch buffers. Bug fix: the input channel count must be
    # len(sequences) (one channel per stacked sequence), not len(images) —
    # those differ whenever `sequences` is a subset of the available keys.
    input_batch = np.zeros((batch_size,) + shape + (len(sequences),))
    output_batch = np.zeros((batch_size,) + shape + (1,))
    # Loop infinitely to keep generating batches
    while True:
        # Prepare each observation in a batch
        for batch_idx in range(batch_size):
            # Shuffle the order of images if all indexes have been seen
            if idx == 0 and shuffle:
                np.random.shuffle(indexes)
            current_index = indexes[idx]
            # Insert the augmented images into the input batch
            img_crop = [images[s][current_index] for s in sequences]
            img_crop = np.stack(img_crop, axis=-1)
            seg_crop = segmentations[current_index][..., np.newaxis]
            if augmentation:
                img_crop, seg_crop = augment(img_crop, seg_crop)
            input_batch[batch_idx] = img_crop
            output_batch[batch_idx] = seg_crop
            # Increase the current index and modulo by the number of rows
            # so that we stay within bounds
            idx = (idx + 1) % len(indexes)
        yield input_batch, output_batch
def resample_to_reference(image, ref_img,
                          interpolator=sitk.sitkNearestNeighbor,
                          default_pixel_value=0):
    """Resample `image` onto the grid of `ref_img` using an identity
    transform; voxels outside the input get `default_pixel_value`."""
    return sitk.Resample(
        image,
        ref_img,
        sitk.Transform(),
        interpolator,
        default_pixel_value,
        ref_img.GetPixelID(),
    )
def dump_dict_to_yaml(
        data: dict,
        target_dir: str,
        filename: str = "settings") -> None:
    """ Writes the given dictionary as a yaml to the target directory.

    Parameters:
        `data (dict)`: dictionary of data.
        `target_dir (str)`: directory where the yaml will be saved.
        `filename (str)`: name of the file without extension.
    """
    print("\nParameters")
    for pair in data.items():
        print(f"\t{pair}")
    print()
    # Bug fix: the output path previously ignored the `filename` parameter.
    # Also renamed the local from `path` to avoid shadowing os.path,
    # which is imported at module level as `path`.
    out_path = f"{target_dir}/{filename}.yml"
    with open(out_path, 'w') as outfile:
        yaml.dump(data, outfile, default_flow_style=False)
    # Log after the write succeeds, so the message is accurate.
    print_(f"Wrote yaml to: {out_path}")