import os
from os import path

import numpy as np
import SimpleITK as sitk
from scipy import ndimage
import yaml


def print_(*args, **kwargs):
    """Print the given arguments and immediately flush stdout."""
    print(*args, **kwargs, flush=True)


def print_config(config, depth=0):
    """Pretty-print a (possibly nested) configuration dictionary.

    Each key is printed as a ``- key:`` bullet indented by two spaces per
    nesting level. Nested dictionaries are printed recursively on the lines
    below their key; any other value is printed after a tab on the same line.

    Parameters:
        config: Mapping of configuration keys to values or nested dicts.
        depth: Current nesting level (controls the indentation).
    """
    indent = " " * depth * 2
    for key, value in config.items():
        print(end=f"{indent}- {key}:")
        if isinstance(value, dict):
            print()
            print_config(value, depth + 1)
        else:
            print(f"\t{value}")
    # Print nothing, only flush stdout.
    print_(end="")


def prepare_project_dir(project_dir):
    """Create the standard sub-directories of a project work directory.

    Creates ``models``, ``logs``, ``data``, ``docs``, ``samples`` and
    ``output`` inside ``project_dir`` (creating ``project_dir`` itself if
    needed). Directories that already exist are left untouched.

    TODO: Create text file containing configuration
    """
    for subdir in ("models", "logs", "data", "docs", "samples", "output"):
        os.makedirs(path.join(project_dir, subdir), exist_ok=True)


def fake_loss(*args, **kwargs):
    """Placeholder loss that ignores its arguments and always returns 0.0."""
    return 0.


def center_crop(
    image: sitk.Image,
    shape: list,
    offset: list = None
) -> sitk.Image:
    """Extracts a region of interest from the center of an SITK image.

    Parameters:
        image: Input image (SITK).
        shape: Size of the region to extract, one entry per image dimension,
            in voxels and in SITK index order (same order as
            ``image.GetSize()``).
        offset: Optional per-dimension shift in voxels. Each entry is
            *subtracted* from the start index of the centered box.

    Returns:
        Cropped image (SITK)
    """
    size = image.GetSize()

    # Start of the box = centre of the image minus half the box size.
    box_start = [int(sz / 2 - sh / 2) for sz, sh in zip(size, shape)]
    if offset:
        box_start = [int(b - o) for b, o in zip(box_start, offset)]

    # Extract the region of the requested shape starting at box_start.
    # Plain ints are passed because SimpleITK expects unsigned/int vectors.
    region_extractor = sitk.RegionOfInterestImageFilter()
    region_extractor.SetSize([int(s) for s in shape])
    region_extractor.SetIndex(box_start)
    return region_extractor.Execute(image)


def augment(img, seg,
            noise_chance=0.3,
            noise_mult_max=0.001,
            rotate_chance=0.2,
            rotate_max_angle=30,
            ):
    """Randomly augment an image/segmentation pair.

    - With probability ``noise_chance``, Gaussian noise is added to ``img``
      only. Its standard deviation is drawn uniformly from
      ``[0, noise_mult_max)``.
    - With probability ``rotate_chance``, the pair is first rotated in the
      plane of axes 0/1 and then tilted in the plane of axes 1/2. Each step
      uses its own angle drawn uniformly from
      ``[-rotate_max_angle, rotate_max_angle)`` degrees.

    Parameters:
        img: Image array.
        seg: Segmentation array matching ``img`` spatially.
        noise_chance: Probability of applying additive noise.
        noise_mult_max: Upper bound for the noise standard deviation
            (was initially 0.1).
        rotate_chance: Probability of applying rotation + tilt.
        rotate_max_angle: Maximum absolute angle in degrees.

    Returns:
        Tuple ``(img, seg)`` of (possibly) augmented arrays.
    """
    if np.random.uniform() < noise_chance:
        img = augment_noise(img, np.random.uniform(0., noise_mult_max))
    if np.random.uniform() < rotate_chance:
        img, seg = augment_rotate(
            img, seg,
            angle=np.random.uniform(-rotate_max_angle, rotate_max_angle))
        # NOTE(review): the tilt shares the rotation's `rotate_chance` draw;
        # confirm this is the intended behaviour.
        img, seg = augment_tilt(
            img, seg,
            angle=np.random.uniform(-rotate_max_angle, rotate_max_angle))
    return img, seg


def augment_noise(img, multiplier):
    """Return ``img`` plus standard-normal noise scaled by ``multiplier``."""
    noise = np.random.standard_normal(img.shape) * multiplier
    return img + noise


def augment_rotate(img, seg, angle):
    """Rotate ``img`` and ``seg`` by ``angle`` degrees in the plane of axes 0/1.

    The output keeps the input shape (``reshape=False``). The image uses
    ndimage's default spline interpolation. The segmentation uses
    nearest-neighbour (``order=0``) so label values are not blended.
    """
    img = ndimage.rotate(img, angle, reshape=False, cval=0)
    seg = ndimage.rotate(seg, angle, reshape=False, order=0)
    return img, seg


def augment_tilt(img, seg, angle):
    """Rotate ``img`` and ``seg`` by ``angle`` degrees in the plane of axes 1/2.

    The output keeps the input shape. The image uses spline interpolation;
    the segmentation uses nearest-neighbour (``order=0``).
    """
    img = ndimage.rotate(img, angle, reshape=False, axes=(2, 1), cval=0)
    seg = ndimage.rotate(seg, angle, reshape=False, axes=(2, 1), order=0)
    return img, seg


def resample(
    image: sitk.Image,
    min_shape: list,
    method=sitk.sitkLinear,
    new_spacing: tuple = (1, 1, 3.6)
) -> sitk.Image:
    """Resamples an image to given target spacing and shape.

    The output size per dimension is ``max(int(size * old_spacing /
    new_spacing), min_shape)``. Origin and direction are taken from the
    input image.

    Parameters:
        image: Input image (SITK).
        min_shape: Minimum output size per dimension (voxels, SITK order).
        method: SimpleITK interpolator to use for resampling
            (e.g. sitk.sitkNearestNeighbor, sitk.sitkLinear).
        new_spacing: The new spacing to resample to (one value per dimension).

    Returns:
        Resampled image (SITK).
    """
    size = image.GetSize()
    spacing = image.GetSpacing()

    # Scaling factor per dimension (> 1 means the image gets more voxels).
    factor = [sp / new_sp for sp, new_sp in zip(spacing, new_spacing)]

    # Outcome size per dimension, never smaller than min_shape. Plain ints
    # are used because SimpleITK's SetSize expects integer values.
    new_size = [max(int(sz * f), int(ms))
                for sz, f, ms in zip(size, factor, min_shape)]

    resampler = sitk.ResampleImageFilter()
    # Copies origin/direction (and spacing/size, overridden below).
    resampler.SetReferenceImage(image)
    resampler.SetOutputSpacing([float(s) for s in new_spacing])
    resampler.SetSize(new_size)
    resampler.SetInterpolator(method)
    return resampler.Execute(image)


# Start training
def get_generator(
    images: dict,
    segmentations: list,
    sequences: list,
    shape: tuple,
    indexes: list = None,
    batch_size: int = 5,
    shuffle: bool = False,
    augmentation=True
):
    """Returns an infinite (training) generator for use with model.fit().

    Parameters:
        images: Mapping from sequence (modality) name to an indexable
            collection of image arrays; ``images[s][i]`` is sample ``i`` of
            sequence ``s``.
        segmentations: Indexable collection of segmentation arrays, one per
            sample.
        sequences: Names of the sequences in ``images`` to stack, in this
            order, as input channels.
        shape: Spatial shape of every image and segmentation.
        indexes: Sample indexes to draw from. ``None`` (default) uses all
            samples in ``segmentations``. An int ``n`` uses samples
            ``0..n-1``.
        batch_size: Number of samples per batch (default 5; ``None`` means
            a single batch covers all ``indexes``).
        shuffle: Reshuffle the indexes at the start of every pass.
        augmentation: Apply random augmentation (see ``augment``) or not.

    Yields:
        ``(input_batch, output_batch)`` arrays of shape
        ``(batch_size, *shape, len(sequences))`` and
        ``(batch_size, *shape, 1)``.
    """
    if indexes is None:
        # One entry per sample. len(images) would count sequences instead.
        indexes = list(range(len(segmentations)))
    elif isinstance(indexes, (int, np.integer)):
        indexes = list(range(indexes))
    else:
        # Copy so that shuffling never reorders the caller's list.
        indexes = list(indexes)
    if batch_size is None:
        batch_size = len(indexes)

    shape = tuple(shape)
    # Only the selected sequences are stacked, so they define the channels.
    num_channels = len(sequences)
    idx = 0

    # Loop infinitely to keep generating batches.
    while True:
        # Fresh arrays per batch: a consumer may still hold (e.g. in a
        # prefetch queue) a previously yielded batch while this one is built.
        input_batch = np.zeros((batch_size,) + shape + (num_channels,))
        output_batch = np.zeros((batch_size,) + shape + (1,))

        for batch_idx in range(batch_size):
            # Shuffle the order of samples whenever a new pass begins.
            if idx == 0 and shuffle:
                np.random.shuffle(indexes)
            current_index = indexes[idx]

            img_crop = np.stack(
                [images[s][current_index] for s in sequences], axis=-1)
            seg_crop = segmentations[current_index][..., np.newaxis]
            if augmentation:
                img_crop, seg_crop = augment(img_crop, seg_crop)

            input_batch[batch_idx] = img_crop
            output_batch[batch_idx] = seg_crop

            # Wrap around so that we stay within bounds.
            idx = (idx + 1) % len(indexes)

        yield input_batch, output_batch


def resample_to_reference(image, ref_img, interpolator=sitk.sitkNearestNeighbor, default_pixel_value=0):
    """Resample ``image`` onto the voxel grid of ``ref_img``.

    Uses ``sitk.Resample`` with a default-constructed ``sitk.Transform()``.
    Voxels that map outside ``image`` get ``default_pixel_value``. The output
    pixel type is that of ``ref_img``.

    NOTE(review): because the reference's pixel type is used, a float image
    resampled onto an integer reference is converted to that integer type.
    Confirm this is intended.
    """
    return sitk.Resample(image, ref_img, sitk.Transform(), interpolator,
                         default_pixel_value, ref_img.GetPixelID())


def dump_dict_to_yaml(
        data: dict,
        target_dir: str,
        filename: str = "settings") -> None:
    """Writes the given dictionary as a yaml to the target directory.

    The key/value pairs are also printed to stdout.

    Parameters:
        `data (dict)`: dictionary of data.
        `target_dir (str)`: directory where the yaml will be saved.
        `filename (str)`: name of the file without extension (".yml" is
            appended).
    """
    print("\nParameters")
    for pair in data.items():
        print(f"\t{pair}")
    print()

    # Named yaml_path so the module-level `os.path` import is not shadowed.
    yaml_path = path.join(target_dir, f"{filename}.yml")
    with open(yaml_path, 'w', encoding="utf-8") as outfile:
        yaml.dump(data, outfile, default_flow_style=False)
    # Report only after the file has actually been written.
    print_(f"Wrote yaml to: {yaml_path}")