fast-mri/src/sfransen/Make_overlay.py

140 lines
5.9 KiB
Python
Executable File

# http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/05_Results_Visualization.html
from cgitb import grey
import SimpleITK as sitk
def make_isotropic(image, interpolator=sitk.sitkLinear, spacing=None):
    """
    Resample an image so that it has the same spacing along every axis.

    Many file formats (e.g. jpg, png, ...) expect the pixels to be isotropic, i.e.
    the same spacing for all axes. Saving non-isotropic data in these formats will
    result in distorted images. This function makes an image isotropic via
    resampling, if needed.

    Args:
        image (SimpleITK.Image): Input image.
        interpolator: By default the function uses a linear interpolator. For
            label images one should use the sitkNearestNeighbor interpolator
            so as not to introduce non-existent labels.
        spacing (float): Desired spacing. If None, use the smallest spacing of
            the original image.
    Returns:
        SimpleITK.Image with spacing `spacing` on every axis, covering the same
        physical region as the input image and keeping its pixel type. A plain
        copy is returned when the input already has that spacing on every axis.
    """
    original_spacing = image.GetSpacing()
    if spacing is None:
        spacing = min(original_spacing)
    # Already at the requested spacing on every axis: no resampling needed.
    # Comparing against the target `spacing` (not just against the first axis)
    # ensures an explicitly requested spacing is honoured for isotropic input too.
    if all(spc == spacing for spc in original_spacing):
        return sitk.Image(image)
    original_size = image.GetSize()
    new_spacing = [spacing] * image.GetDimension()
    # Preserve the physical extent (size * spacing) of each axis; never let an
    # axis collapse to 0 voxels when the requested spacing is very coarse.
    new_size = [
        max(1, int(round(osz * ospc / spacing)))
        for osz, ospc in zip(original_size, original_spacing)
    ]
    return sitk.Resample(
        image,
        new_size,
        sitk.Transform(),  # identity transform: same physical space, new grid
        interpolator,
        image.GetOrigin(),
        new_spacing,
        image.GetDirection(),
        0,  # defaultPixelValue
        image.GetPixelID(),  # keep the input pixel type
    )
def mask_image_multiply(mask, image):
    """
    Multiply `image` by `mask`, handling both scalar and vector pixel types.

    For a scalar image the product is taken directly. For a multi-component
    image each component is extracted, multiplied by the mask, and the
    components are recomposed into a vector image.
    """
    n_channels = image.GetNumberOfComponentsPerPixel()
    if n_channels != 1:
        masked_channels = [
            mask * sitk.VectorIndexSelectionCast(image, idx)
            for idx in range(n_channels)
        ]
        return sitk.Compose(masked_channels)
    return mask * image
def alpha_blend(image1, image2, alpha=0.5, mask1=None, mask2=None):
    """
    Alpha blend two images; pixels can be scalars or vectors.

    Inside the intersection of the two masks the result is
    ``alpha * image1 + (1 - alpha) * image2``. Where only mask1 (resp. mask2) is
    set, image1 (resp. image2) is copied unchanged. Where neither mask is set,
    the result is 0.

    Args:
        image1, image2 (SimpleITK.Image): Images to blend. Whether pixels are
            treated as scalar or vector is decided from image1 alone, so both
            should have the same number of components per pixel (and presumably
            the same geometry -- TODO confirm with callers).
        alpha (float or SimpleITK.Image): Blending factor, either a scalar or an
            image whose pixel type is sitkFloat32 with values in [0, 1].
        mask1, mask2 (SimpleITK.Image or None): Regions where image1 / image2
            contribute. None means "the whole image". The formula above assumes
            binary (0/1) masks.
    Returns:
        SimpleITK.Image with float32 (or vector-of-float32) pixels.
    """
    # Missing masks default to all-ones images on the corresponding image grid.
    if mask1 is None:
        mask1 = sitk.Image(image1.GetSize(), sitk.sitkFloat32) + 1.0
        mask1.CopyInformation(image1)
    else:
        mask1 = sitk.Cast(mask1, sitk.sitkFloat32)
    if mask2 is None:
        mask2 = sitk.Image(image2.GetSize(), sitk.sitkFloat32) + 1.0
        mask2.CopyInformation(image2)
    else:
        mask2 = sitk.Cast(mask2, sitk.sitkFloat32)
    # A scalar alpha is expanded to a constant image so the arithmetic below is uniform.
    if not isinstance(alpha, sitk.Image):
        alpha = sitk.Image(image1.GetSize(), sitk.sitkFloat32) + alpha
        alpha.CopyInformation(image1)
    components_per_pixel = image1.GetNumberOfComponentsPerPixel()
    if components_per_pixel > 1:
        img1 = sitk.Cast(image1, sitk.sitkVectorFloat32)
        img2 = sitk.Cast(image2, sitk.sitkVectorFloat32)
    else:
        img1 = sitk.Cast(image1, sitk.sitkFloat32)
        img2 = sitk.Cast(image2, sitk.sitkFloat32)
    intersection_mask = mask1 * mask2
    # Blended contribution where both masks are set.
    intersection_image = mask_image_multiply(
        alpha * intersection_mask, img1
    ) + mask_image_multiply((1 - alpha) * intersection_mask, img2)
    # Add the parts covered by only one of the masks, unblended.
    return (
        intersection_image
        + mask_image_multiply(mask2 - intersection_mask, img2)
        + mask_image_multiply(mask1 - intersection_mask, img1)
    )
# Dictionary with functions mapping a scalar image to a three-component (RGB) vector image.
image_mappings = {
    'grey': lambda x: sitk.Compose([x] * 3),
    'jet': lambda x: sitk.ScalarToRGBColormap(x, sitk.ScalarToRGBColormapImageFilter.Jet),
    'hot': lambda x: sitk.ScalarToRGBColormap(x, sitk.ScalarToRGBColormapImageFilter.Hot),
    'winter': lambda x: sitk.ScalarToRGBColormap(x, sitk.ScalarToRGBColormapImageFilter.Winter),
}

# Every input and output file shares this directory and case prefix; the
# individual files differ only in their suffix.
case_path = '../train_output/train_b50_b400_b800/output/b-800_b-400_b-50_000'

image = sitk.ReadImage(case_path + '_pred.nii')
segmentation = sitk.ReadImage(case_path + '_seg.nii')
# Make the images isotropic so that when we save a slice along any axis
# it will be fine (most image formats assume isotropic pixel sizes).
# The segmentation is interpolated with nearest neighbor so we don't introduce
# new labels.
image = make_isotropic(image, interpolator=sitk.sitkLinear)
segmentation = make_isotropic(segmentation, interpolator=sitk.sitkNearestNeighbor)
# Convert the image to sitkUInt8 after rescaling; colour maps expect values in [0, 255].
image_255 = sitk.Cast(sitk.RescaleIntensity(image, 0, 255), sitk.sitkUInt8)

# Colour map for the base rendering and for the tint used inside the segmentation.
colormap = 'hot'
label_colormap = 'winter'
vec_image = image_mappings[colormap](image_255)
vec_image_label = image_mappings[label_colormap](image_255)
# With mask1 left at its default (whole image): outside label 1 the base
# rendering is kept as-is; inside label 1 the two renderings are mixed 50/50.
vec_combined = sitk.Cast(
    alpha_blend(image1=vec_image, image2=vec_image_label, alpha=0.5, mask2=segmentation == 1),
    sitk.sitkVectorUInt8,
)
sitk.WriteImage(vec_image, case_path + '_vec_image.nii.gz')
sitk.WriteImage(vec_image_label, case_path + '_vec_image_label.nii.gz')
sitk.WriteImage(vec_combined, case_path + '_vec_com.nii.gz')