# fast-mri/code/pred_label_images.py
# (75 lines, 2.6 KiB, Python, executable file)

# A script that returns the predictions with a label overlay and stores them in a new folder named "report"
#modules
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm import tqdm

from preprocessing_function import preprocess
from helpers import *
# ---- Model load ----
# NOTE(review): '..' is a placeholder; point this at the saved model before
# running.
Model_path = '..'
# compile=False: the model is only used for inference (`predict` below), which
# needs no loss/optimizer. Re-compiling with empty identifiers
# (loss='', optimizer='', metrics=['']) fails, because Keras cannot resolve an
# empty name to a loss or optimizer.
reconstructed_model = tf.keras.models.load_model(Model_path, compile=False)
# ---- Image load ----
# NOTE(review): these names are not defined in this script and must be
# provided before it can run (placeholders left by the original author):
#   args.series    -- series names to load, e.g. ["b-50", "b-400"]
#   DATA_DIR       -- directory holding "<series>.txt" and "seg.txt"
#   IMAGE_SHAPE    -- target shape forwarded to `preprocess`
#   TARGET_SPACING -- target spacing forwarded to `preprocess`
# `sitk` and `print_` are presumably provided by `from helpers import *`
# (TODO confirm).
REPORT_DIR = "report"  # output folder for all exported volumes
prefix = os.path.join(REPORT_DIR, "case")  # file-name prefix for every export


def _read_path_list(list_file):
    """Return the stripped, non-blank lines of the text file *list_file*.

    Blank lines (e.g. a trailing empty line) are skipped so they are never
    passed to ``sitk.ReadImage`` as an empty path.
    """
    with open(list_file, 'r') as f:
        return [line.strip() for line in f if line.strip()]


def _write_volume(array, filename):
    """Squeeze *array*, reverse its axis order (``.T``) and write it with SimpleITK."""
    sitk.WriteImage(sitk.GetImageFromArray(array.squeeze().T), filename)


sequences = list(args.series)
images, image_paths = {s: [] for s in sequences}, {}
segmentations = []
print_("> Loading images into RAM...")

# Read the image paths from the data directory.
# Text files are expected to be named "[series_name].txt"; "seg.txt" lists the
# matching segmentations, one path per line, in the same order.
for s in sequences:
    image_paths[s] = _read_path_list(os.path.join(DATA_DIR, f"{s}.txt"))
seg_list_file = os.path.join(DATA_DIR, "seg.txt")
seg_paths = _read_path_list(seg_list_file)
num_images = len(seg_paths)

# Fail early with a clear message instead of an IndexError mid-loop or a
# cryptic "need at least one array to stack" from np.stack.
if num_images == 0:
    raise ValueError(f"No segmentation paths listed in {seg_list_file}")
for s in sequences:
    if len(image_paths[s]) != num_images:
        raise ValueError(
            f"{s}.txt lists {len(image_paths[s])} images but seg.txt lists "
            f"{num_images} segmentations")

# Read and preprocess each of the paths for each series, and the segmentations.
for img_idx in tqdm(range(num_images)):  # use range(num_images)[:40] for fewer images
    img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)
             for s in sequences}
    seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
    img_n, seg_n = preprocess(img_s, seg_s,
                              shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
    for seq in img_n:
        images[seq].append(img_n[seq])
    segmentations.append(seg_n)

# Stack into arrays: cases along axis 0, and one channel per series (in
# `sequences` order) along the last axis of the image array.
images_array = np.stack(
    [np.stack(images[s], axis=0) for s in sequences], axis=-1)
segmentations_array = np.stack(segmentations, axis=0)

# ---- Image predict ----
# Runs the model on every loaded case. Assumes the model accepts
# `images_array` laid out as above -- TODO confirm against its input shape.
predictions = reconstructed_model.predict(images_array)

# ---- Export to the report folder ----
# TODO: heatmap overlay of the prediction on the input image is not
# implemented yet.
# For each case, write the input channels, the segmentation, the raw
# prediction, and the prediction rounded with np.around (0.5 rounds to 0).
# The rounding binarises the output only if predictions lie in [0, 1]
# (presumably sigmoid probabilities -- TODO confirm).
os.makedirs(REPORT_DIR, exist_ok=True)
for i in range(images_array.shape[0]):
    for s_idx, s in enumerate(sequences):
        _write_volume(images_array[i][..., s_idx],
                      f"{prefix}_{i:03d}_{s}.nii.gz")
    _write_volume(segmentations_array[i], f"{prefix}_{i:03d}_seg.nii.gz")
    _write_volume(predictions[i], f"{prefix}_{i:03d}_pred.nii.gz")
    _write_volume(np.around(predictions[i]).astype(np.float32),
                  f"{prefix}_{i:03d}_pred_bin.nii.gz")