# fast-mri/scripts/test3.py
#
# 66 lines
# 3.0 KiB
# Python
# Executable File

from glob import glob
from os.path import normpath, basename
import SimpleITK as sitk
import numpy as np
import os
from os import path
from sfransen.utils_quintin import *
from sfransen.DWI_exp.helpers import *
from sfransen.DWI_exp.preprocessing_function import preprocess
from sfransen.DWI_exp.callbacks import dice_coef
#from sfransen.FROC.blob_preprocess import *
from sfransen.FROC.cal_froc_from_np import *
from sfransen.load_images import load_images_parrallel
from sfransen.DWI_exp.losses import weighted_binary_cross_entropy
from umcglib.froc import *
from umcglib.binarize import dynamic_threshold
from tensorflow.keras.models import load_model
def get_paths(main_dir):
    """Find b800 and b400 diffusion-weighted NIfTI paths matching a glob pattern.

    Args:
        main_dir: glob pattern, expanded with ``recursive=True`` so ``**``
            matches any number of sub-directories.

    Returns:
        Tuple ``(dwis_b800, dwis_b400)`` of path lists, each sorted so the
        order does not depend on the filesystem.

        A path counts as a DWI when its lower-cased form (the whole path, not
        just the file name) contains ``"diff"`` or ``"dwi"``. Its b-value is
        recognised from ``"b-800"``/``"b800"`` or ``"b-400"``/``"b400"``.
    """
    # glob() returns filesystem order; sort for reproducible downstream reads.
    all_niftis = sorted(glob(main_dir, recursive=True))

    def _is_dwi_with_b(candidate, b_value):
        """Return True if *candidate* looks like a DWI path with *b_value*."""
        lowered = candidate.lower()
        is_dwi = "diff" in lowered or "dwi" in lowered
        has_b = f"b-{b_value}" in lowered or f"b{b_value}" in lowered
        return is_dwi and has_b

    dwis_b800 = [p for p in all_niftis if _is_dwi_with_b(p, 800)]
    dwis_b400 = [p for p in all_niftis if _is_dwi_with_b(p, 400)]
    return dwis_b800, dwis_b400


def get_paths_seg(main_dir):
    """Return the paths matching the glob pattern *main_dir*, sorted.

    The pattern is expanded with ``recursive=True`` (``**`` spans
    sub-directories). The result is sorted because ``glob`` returns
    filesystem order, which is not reproducible. It is an empty list when
    nothing matches.
    """
    return sorted(glob(main_dir, recursive=True))


def get_paths_train(dir, SERIES, pat_id):
    """Collect, for each series, the listed image paths that belong to one patient.

    For every series name ``s`` in *SERIES*, the text file ``<dir>/<s>.txt``
    is read line by line. Each whitespace-stripped line is treated as a path,
    and only lines containing *pat_id* as a substring are kept.

    Returns:
        dict mapping each series name to its list of matching paths, in the
        order they appear in the file.
    """
    selected = {}
    for series_name in SERIES:
        list_file = path.join(dir, f"{series_name}.txt")
        with open(list_file, 'r') as handle:
            candidates = [line.strip() for line in handle]
        selected[series_name] = [
            candidate for candidate in candidates if pat_id in candidate
        ]
    return selected


# Patient selections. NOTE: the first `pat_numbers_worst` list is overridden by
# the shorter list assigned two lines below; only the latter is exported.
pat_numbers_worst = ['pat0132','pat0091','pat0352','pat0844','pat1006','pat0406','pat0128','pat0153','pat0062','pat0758','pat0932','pat0248','pat0129','pat0429','pat0181','pat0063','pat0674','pat0176','pat0366','pat0082']
pat_numbers_best = ['pat0651','pat0889','pat0448','pat1022','pat0887','pat0194','pat0603','pat0742','pat0811','pat0489','pat0622','pat0582','pat0105','pat0084','pat0643','pat0529','pat0476','pat0514','pat0506','pat0567']
pat_numbers_worst = ['pat0132', 'pat0091','pat0352','pat0844','pat1006','pat0636','pat1009','pat0584','pat0588','pat0198']

# Glob templates; `{pat_number}` is filled in per patient via str.format.
load_path = '../../datasets/radboud_new/{pat_number}/2016/**/*.nii.gz'
seg_load_path = '/data/pca-rad/datasets/radboud_lesions_2022/{pat_number}*.nii.gz'
output_dir = '../temp/lowest_pred_exp/worst'
# sitk.WriteImage does not create missing parent directories.
os.makedirs(output_dir, exist_ok=True)

for idx, pat_number in enumerate(pat_numbers_worst):
    print(pat_number)
    dwis_b800, dwis_b400 = get_paths(load_path.format(pat_number=pat_number))
    seg_path = get_paths_seg(seg_load_path.format(pat_number=pat_number))

    # An empty file list cannot be read; report and move on instead of
    # aborting the whole batch on one patient with missing data.
    missing = [
        name
        for name, found in (('b800', dwis_b800), ('b400', dwis_b400), ('segmentation', seg_path))
        if not found
    ]
    if missing:
        print(f'skipping {pat_number}: no files found for {", ".join(missing)}')
        continue

    # load
    # NOTE(review): the path lists are passed as-is. If more than one file
    # matches, SimpleITK reads the list as an image series; confirm that one
    # match per patient is expected.
    dwi_b800 = sitk.ReadImage(dwis_b800, sitk.sitkFloat32)
    dwi_b400 = sitk.ReadImage(dwis_b400, sitk.sitkFloat32)
    seg = sitk.ReadImage(seg_path, sitk.sitkFloat32)
    seg = sitk.GetArrayFromImage(seg)
    # Sum of the mask after clipping values to [0, 1]; for an integer label
    # map this is the number of non-background voxels.
    print('count:', np.sum(np.clip(seg, 0, 1)))

    # write (idx comes from enumerate, so skipped patients keep their numbering)
    output_path_b800 = path.join(output_dir, f'{idx}_{pat_number}_b800.nii.gz')
    output_path_b400 = path.join(output_dir, f'{idx}_{pat_number}_b400.nii.gz')
    sitk.WriteImage(dwi_b800, output_path_b800)
    sitk.WriteImage(dwi_b400, output_path_b400)
###################################################################################################################