239 lines
7.4 KiB
Python
Executable File
239 lines
7.4 KiB
Python
Executable File
import os
|
|
from os import path
|
|
|
|
import numpy as np
|
|
import SimpleITK as sitk
|
|
from scipy import ndimage
|
|
import yaml
|
|
|
|
def print_(*args, **kwargs):
    """Print exactly like the builtin ``print``, but always flush the stream.

    Every positional and keyword argument is forwarded to ``print``; passing
    ``flush=`` explicitly is an error (duplicate keyword argument).
    """
    return print(*args, flush=True, **kwargs)
|
|
|
|
def print_config(config, depth=0):
    """Pretty-print a (possibly nested) configuration mapping to stdout.

    Each key is printed as ``- key:``, indented by two spaces per nesting
    level. Dictionary values are expanded recursively on the following lines;
    any other value is printed on the same line, after a tab.

    Parameters:
        config: Mapping of configuration keys to values or nested dicts.
        depth: Current nesting level; controls indentation (used by the
            recursive calls, callers normally leave it at 0).
    """
    indent = " " * depth * 2
    for key, value in config.items():
        print(end=f"{indent}- {key}:")
        # isinstance (not an exact type check) so dict subclasses such as
        # OrderedDict / defaultdict are expanded instead of dumped raw.
        if isinstance(value, dict):
            print()
            print_config(value, depth + 1)
        else:
            print(f"\t{value}")
    # Flush so the output shows up promptly even when stdout is buffered.
    print(end="", flush=True)
|
|
|
|
def prepare_project_dir(project_dir):
    """
    Create the standard sub-directory layout inside a project work directory.

    The sub-directories ``models``, ``logs``, ``data``, ``docs``, ``samples``
    and ``output`` are created (in that order) under ``project_dir``, along
    with any missing parent directories. Directories that already exist are
    left untouched.

    TODO: Create text file containing configuration
    """
    for subdir in ("models", "logs", "data", "docs", "samples", "output"):
        target = path.join(project_dir, subdir)
        os.makedirs(target, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
def fake_loss(*args, **kwargs):
    """Placeholder loss: ignore every argument and always return ``0.0``."""
    constant_loss = 0.0
    return constant_loss
|
|
|
|
def center_crop(
    image: sitk.Image, shape: list, offset: list = None
) -> sitk.Image:
    """Extracts a region of interest from the center of an SITK image.

    Parameters:
        image: Input image (SITK).
        shape: Size of the region to extract, one entry per image dimension,
            in the same order as ``image.GetSize()``.
        offset: Optional per-dimension shift (in voxels) of the region's start
            index. It is *subtracted* from the centred start index, so
            positive values move the crop towards index 0. ``None`` or an
            empty sequence means no shift.

    Returns: Cropped image (SITK)
    """
    size = image.GetSize()

    # Determine the centroid
    centroid = [sz / 2 for sz in size]

    # Determine the origin for the bounding box by subtracting half the
    # shape of the bounding box from each dimension of the centroid
    # (int() truncates the fractional part).
    box_start = [int(c - sh / 2) for c, sh in zip(centroid, shape)]

    # Explicit None/len check: a bare `if offset:` raises for multi-element
    # numpy arrays ("truth value of an array is ambiguous").
    if offset is not None and len(offset) > 0:
        box_start = [int(b - o) for b, o in zip(box_start, offset)]

    # Extract the region of provided shape starting from the previously
    # determined starting pixel. Sizes/indices are cast to plain Python ints
    # because the SWIG-wrapped setters may reject numpy integer types.
    region_extractor = sitk.RegionOfInterestImageFilter()
    region_extractor.SetSize([int(sh) for sh in shape])
    region_extractor.SetIndex(box_start)
    return region_extractor.Execute(image)
|
|
|
|
def augment(img, seg,
            noise_chance=0.3,
            noise_mult_max=0.001,
            rotate_chance=0.2,
            rotate_max_angle=30,
            ):
    """Randomly augment an image / segmentation pair.

    With probability ``noise_chance`` Gaussian noise (via ``augment_noise``),
    scaled by a factor drawn uniformly from ``[0, noise_mult_max)``, is added
    to ``img`` only. With probability ``rotate_chance`` both arrays are first
    rotated (``augment_rotate``) and then tilted (``augment_tilt``), each by
    its own angle drawn uniformly from ``[-rotate_max_angle, rotate_max_angle)``
    degrees.

    Returns:
        The (possibly) augmented ``(img, seg)`` pair.
    """
    # noise_mult_max was initially 0.1
    noise_draw = np.random.uniform()
    if noise_draw < noise_chance:
        noise_scale = np.random.uniform(0., noise_mult_max)
        img = augment_noise(img, noise_scale)

    rotate_draw = np.random.uniform()
    if rotate_draw < rotate_chance:
        low, high = 0 - rotate_max_angle, rotate_max_angle
        rotation_angle = np.random.uniform(low, high)
        img, seg = augment_rotate(img, seg, angle=rotation_angle)
        tilt_angle = np.random.uniform(low, high)
        img, seg = augment_tilt(img, seg, angle=tilt_angle)

    return img, seg
|
|
|
|
def augment_noise(img, multiplier):
    """Return ``img`` plus zero-mean Gaussian noise, as a new array.

    The noise has the same shape as ``img`` and is standard normal scaled by
    ``multiplier`` (i.e. ``multiplier`` is its standard deviation).
    """
    gaussian = np.random.standard_normal(img.shape)
    return img + gaussian * multiplier
|
|
|
|
def augment_rotate(img, seg, angle):
    """Rotate an image and its segmentation by ``angle`` degrees.

    Uses ``scipy.ndimage.rotate`` with its default ``axes=(1, 0)`` and
    ``reshape=False``, so output shapes equal input shapes. The image uses
    the default spline order and zero fill; the segmentation uses
    ``order=0`` (nearest neighbour) so label values are not interpolated.
    """
    rotated_img = ndimage.rotate(img, angle=angle, reshape=False, cval=0)
    rotated_seg = ndimage.rotate(seg, angle=angle, reshape=False, order=0)
    return rotated_img, rotated_seg
|
|
|
|
def augment_tilt(img, seg, angle):
    """Tilt an image and its segmentation by ``angle`` degrees.

    Same as a rotation, but in the plane of axes 2 and 1 (``axes=(2, 1)``)
    instead of scipy's default. Shapes are preserved (``reshape=False``); the
    segmentation uses nearest-neighbour interpolation (``order=0``).
    """
    tilt_axes = (2, 1)
    tilted_img = ndimage.rotate(img, angle=angle, reshape=False,
                                axes=tilt_axes, cval=0)
    tilted_seg = ndimage.rotate(seg, angle=angle, reshape=False,
                                axes=tilt_axes, order=0)
    return tilted_img, tilted_seg
|
|
|
|
def resample(
    image: sitk.Image, min_shape: list, method=sitk.sitkLinear,
    new_spacing: tuple = (1, 1, 3.6)
) -> sitk.Image:
    """Resamples an image to given target spacing and shape.

    Parameters:
        image: Input image (SITK).
        min_shape: Minimum output size per dimension (same order as
            ``image.GetSize()``). A dimension that would come out smaller
            after rescaling is enlarged to this value.
        method: SimpleITK interpolator to use for resampling.
            (e.g. sitk.sitkNearestNeighbor, sitk.sitkLinear)
        new_spacing: The new spacing to resample to (any sequence of numbers;
            default ``(1, 1, 3.6)``).

    Returns:
        sitk.Image: Resampled image. Origin and direction are copied from
        ``image``, so any enlargement required by ``min_shape`` extends the
        grid at the far end of each axis. Voxels there are filled with the
        filter's default pixel value.
    """
    # Extract size and spacing from the image
    size = image.GetSize()
    spacing = image.GetSpacing()

    # Ratio old/new spacing: > 1 means the image gains voxels along that axis
    factor = [sp / new_sp for sp, new_sp in zip(spacing, new_spacing)]

    # Rescaled (truncated) size, but never below the requested minimum.
    # Cast to plain int because SimpleITK's SetSize may reject numpy/float values.
    new_size = [
        int(max(int(sz * f), sh)) for sz, f, sh in zip(size, factor, min_shape)
    ]

    # Resample the image
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(image)
    resampler.SetOutputSpacing([float(sp) for sp in new_spacing])
    resampler.SetSize(new_size)
    resampler.SetInterpolator(method)
    return resampler.Execute(image)
|
|
|
|
# Batch generator used for training
|
|
def get_generator(
    images: dict,
    segmentations: list,
    sequences: list,
    shape: tuple,
    indexes: list = None,
    batch_size: int = 5,
    shuffle: bool = False,
    augmentation = True
):
    """
    Returns an infinite (training) batch generator for use with model.fit().

    Parameters:
        images: Mapping from sequence name to an indexable collection of
            per-case image arrays; ``images[s][i]`` is sequence ``s`` of case ``i``.
        segmentations: Per-case segmentation arrays, indexed like the
            collections in ``images``.
        sequences: Names of the ``images`` entries to stack (in this order)
            as input channels.
        shape: Spatial shape of a single image / segmentation (tuple or list).
        indexes: Case indexes to draw from. ``None`` (default) uses every case
            in ``segmentations``; an int ``n`` uses cases ``0 .. n-1``. The
            caller's list is never modified.
        batch_size: Number of cases per batch (default 5); ``None`` makes one
            batch contain all selected indexes.
        shuffle: Reshuffle the index order every time a new pass over the
            indexes begins.
        augmentation: Apply random augmentation (``augment``) or not (bool).

    Yields:
        ``(input_batch, output_batch)`` with shapes
        ``(batch_size, *shape, len(sequences))`` and ``(batch_size, *shape, 1)``.
        Fresh arrays are yielded each time.
    """
    shape = tuple(shape)

    if indexes is None:
        # One entry per case. `images` is keyed by sequence name, so its
        # length is the number of sequences, not the number of cases.
        indexes = list(range(len(segmentations)))
    elif isinstance(indexes, int):
        indexes = list(range(indexes))
    else:
        # Copy so that shuffling never reorders the caller's list in place.
        indexes = list(indexes)

    if batch_size is None:
        batch_size = len(indexes)

    num_channels = len(sequences)
    idx = 0

    # Loop infinitely to keep generating batches
    while True:
        # Allocate new arrays for every batch. A consumer (e.g. a prefetching
        # queue) may still hold the previously yielded batch, which must not
        # be overwritten in place.
        input_batch = np.zeros((batch_size,) + shape + (num_channels,))
        output_batch = np.zeros((batch_size,) + shape + (1,))

        for batch_idx in range(batch_size):
            # Reshuffle whenever a new pass over all indexes begins
            if idx == 0 and shuffle:
                np.random.shuffle(indexes)
            current_index = indexes[idx]

            # Stack the selected sequences as channels (last axis)
            img_crop = np.stack(
                [images[s][current_index] for s in sequences], axis=-1
            )
            seg_crop = segmentations[current_index][..., np.newaxis]
            if augmentation:
                img_crop, seg_crop = augment(img_crop, seg_crop)
            input_batch[batch_idx] = img_crop
            output_batch[batch_idx] = seg_crop

            # Wrap around so we stay within bounds
            idx = (idx + 1) % len(indexes)

        yield input_batch, output_batch
|
|
|
|
def resample_to_reference(image, ref_img,
                          interpolator=sitk.sitkNearestNeighbor,
                          default_pixel_value=0):
    """Resample ``image`` onto the voxel grid of ``ref_img``.

    A default-constructed ``sitk.Transform()`` is used as the transform, so
    the mapping between the two grids comes from the images' own
    physical-space metadata. Output voxels that fall outside ``image`` get
    ``default_pixel_value``.

    Note: the output pixel type is taken from ``ref_img``, not from ``image``.
    """
    transform = sitk.Transform()
    output_pixel_type = ref_img.GetPixelID()
    return sitk.Resample(image, ref_img, transform, interpolator,
                         default_pixel_value, output_pixel_type)
|
|
|
|
def dump_dict_to_yaml(
        data: dict,
        target_dir: str,
        filename: str = "settings") -> None:
    """ Writes the given dictionary as a yaml to the target directory.

    The key/value pairs are also echoed to stdout before writing.

    Parameters:
        `data (dict)`: dictionary of data.
        `target_dir (str)`: existing directory where the yaml will be saved.
        `filename (str)`: name of the file without extension (``.yml`` is
            appended).
    """
    print("\nParameters")
    for pair in data.items():
        print(f"\t{pair}")
    print()

    # Named yaml_path so the module-level `path` (os.path) is not shadowed;
    # path.join also avoids doubled separators when target_dir ends with "/".
    yaml_path = path.join(target_dir, f"{filename}.yml")
    with open(yaml_path, 'w', encoding="utf-8") as outfile:
        yaml.dump(data, outfile, default_flow_style=False)
    # Report only after the write succeeded, so the message is never misleading.
    print_(f"Wrote yaml to: {yaml_path}")