fast-mri/src/sfransen/load_images.py

44 lines
1.3 KiB
Python
Executable File

from typing import List, Optional

import numpy as np
import SimpleITK as sitk

from sfransen.DWI_exp.helpers import *
def load_images_parrallel(
    image_paths: str,
    seq: str,
    target_shape: List[int],
    target_space: Optional[List[float]] = None,
) -> np.ndarray:
    """Read one image, resample and center-crop it, and return it as an array.

    The image is read as float32, resampled with nearest-neighbour
    interpolation to ``target_space``, center-cropped to ``target_shape``,
    converted to a NumPy array and transposed (``.T``).  The array is then
    either intensity-normalised or clipped, depending on ``seq``.

    Args:
        image_paths: Path handed to ``sitk.ReadImage``.
        seq: Sequence label.  ``'seg'`` clips the values to ``[0, 1]``; any
            other value applies ``(x - mean) / (2 * std)`` normalisation.
        target_shape: Shape passed to ``center_crop``.  ``resample`` is asked
            for a minimum shape one larger than this on every axis.
        target_space: Spacing passed to ``resample`` as ``new_spacing``.
            Required in practice: the ``None`` default only exists so that
            existing call styles keep working.

    Returns:
        The processed image as a ``numpy.ndarray``.

    Raises:
        TypeError: If ``target_space`` is omitted or ``None``.
    """
    # The previous signature read ``target_space = List[float]``, which made
    # the typing object itself the default spacing.  Fail early and clearly,
    # before any file is read.
    if target_space is None:
        raise TypeError(
            "load_images_parrallel() requires 'target_space' "
            "(the spacing passed to resample as new_spacing)"
        )

    img_s = sitk.ReadImage(image_paths, sitk.sitkFloat32)

    # Resample.  ``resample`` and ``center_crop`` are presumably provided by
    # the star import from sfransen.DWI_exp.helpers -- TODO confirm.
    # min_shape is materialised as a tuple because a generator can only be
    # consumed once and cannot be indexed.
    mri_tra_s = resample(
        img_s,
        min_shape=tuple(s + 1 for s in target_shape),
        method=sitk.sitkNearestNeighbor,
        new_spacing=target_space,
    )

    # Center crop to the requested shape.
    mri_tra_s = center_crop(mri_tra_s, shape=target_shape)

    # SimpleITK hands back arrays with the index order reversed relative to
    # the image; ``.T`` reverses the axes again.
    img_n = sitk.GetArrayFromImage(mri_tra_s).T

    if seq == 'seg':
        return np.clip(img_n, 0., 1.)

    # Normalisation is kept exactly as it was: per the original author's
    # note, this is the method the U-Net was trained with.
    centered = img_n - np.mean(img_n)
    spread = 2 * np.std(img_n)
    if spread == 0:
        # Constant image: dividing by zero would fill the volume with
        # NaN/inf.  ``centered`` is already all zeros, so return it as-is.
        return centered
    return centered / spread