from umcglib.froc import calculate_froc, partial_auc
from umcglib.binarize import dynamic_threshold
import optuna
import sqlite3
from sfransen.utils_quintin import *
from os import path
from sfransen.DWI_exp.preprocessing_function import preprocess
import SimpleITK as sitk
import numpy as np
from sfransen.DWI_exp.callbacks import dice_coef
from sfransen.DWI_exp.losses import weighted_binary_cross_entropy
from tensorflow.keras.models import load_model
from tqdm import tqdm
from optuna.samplers import TPESampler
import argparse
import shutil

parser = argparse.ArgumentParser(
    description='Calculate the FROC metrics and store them in froc_metrics_optuna_test.yml')
parser.add_argument('-experiment',
                    help='Title of the experiment')
parser.add_argument('-series', '-s',
                    metavar='[series_name]', required=True, nargs='+',
                    help='List of series to include')
parser.add_argument('-fold', default='',
                    help='Fold identifier, used in the Optuna database filename')
args = parser.parse_args()


def does_table_exist(tablename: str, db_path: str):
    """Return True if `tablename` exists in the SQLite database at `db_path`."""
    conn = sqlite3.connect(db_path)
    c = conn.cursor()

    # Get the count of tables with the given name; a parameterized query
    # avoids splicing the table name directly into the SQL string.
    c.execute("SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
              (tablename,))

    # If the count is 1, then the table exists.
    does_exist = c.fetchone()[0] == 1
    if does_exist:
        print(f"Table '{tablename}' exists.")
    else:
        print(f"Table '{tablename}' does not exist.")

    # Close the connection (read-only query, nothing to commit).
    conn.close()
    return does_exist


def load_or_create_study(
    is_new_study: bool,
    study_dir: str,
):
    # Create an Optuna study if it does not exist yet; otherwise load it
    # from the SQLite database file.
    storage = f"sqlite:///{study_dir}/{DB_FNAME}"
    if is_new_study:
        print(f"Creating a NEW study. With name: {storage}")
        study = optuna.create_study(storage=storage,
                                    study_name=study_dir,
                                    direction='maximize',
                                    sampler=TPESampler(n_startup_trials=N_STARTUP_TRIALS))
    else:
        print(f"LOADING study {storage} from database file.")
        study = optuna.load_study(storage=storage,
                                  study_name=study_dir)
    return study


def p_auc_froc_obj(trial, y_true_val, y_pred_val):
    # Objective: maximize the partial area under the FROC curve between
    # 0.1 and 2.5 false positives per patient.
    dyn_thresh = trial.suggest_float('dyn_thresh', 0.0, 1.0)
    min_conf = trial.suggest_float('min_conf', 0.0, 1.0)

    stats = calculate_froc(y_true=y_true_val,
                           y_pred=y_pred_val,
                           preprocess_func=dynamic_threshold,
                           dynamic_threshold_factor=dyn_thresh,
                           minimum_confidence=min_conf)

    sens, fpp = stats['sensitivity'], stats['fp_per_patient']
    p_auc_froc = partial_auc(sens, fpp, low=0.1, high=2.5)

    print(f"dyn_threshold: {dyn_thresh}, min_conf: {min_conf}")
    print(f"Trial {trial.number} pAUC FROC: {p_auc_froc}")
    return p_auc_froc


def convert_np_to_list(flat_numpy_arr):
    # YAML serialization needs plain Python floats, not numpy scalars.
    return [float(elem) for elem in flat_numpy_arr]


# >>>>>>>>> main <<<<<<<<<<<<< #
num_trials = 50
N_STARTUP_TRIALS = 10

SERIES = args.series
series_ = '_'.join(SERIES)
EXPERIMENT = args.experiment
DB_FNAME = f'{EXPERIMENT}_{series_}_{args.fold}.db'

MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}.h5'
YAML_DIR = f'./../train_output/{EXPERIMENT}_{series_}'
DATA_DIR = "./../data/Nijmegen paths/"

TARGET_SPACING = (0.5, 0.5, 3)
INPUT_SHAPE = (192, 192, 24, len(SERIES))
IMAGE_SHAPE = INPUT_SHAPE[:3]

DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
print("val_set0 indices:", TEST_INDEX[:5])

# NOTE: the Optuna study below was optimized on predictions for the
# validation split ('val_set0'), produced by the same load/predict pipeline
# that is used for the test split further down. The tuned parameters are
# read back from the SQLite database instead of being recomputed here.
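# Example invocation (a sketch; the script filename and the experiment and
# series names below are hypothetical, chosen to match the naming pattern
# this script uses for its database files):
#
#   python calc_froc_optuna.py -experiment calc_exp -series t2 b1400 adc -fold 0
#
# With those arguments DB_FNAME resolves to 'calc_exp_t2_b1400_adc_0.db' and
# MODEL_PATH to './../train_output/calc_exp_t2_b1400_adc/models/calc_exp_t2_b1400_adc.h5'.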
study_dir = "./../sqliteDB/optuna_dbs"

# If the database for this experiment does not exist yet, start from a copy
# of the original (template) database.
if not path.isfile(f"{study_dir}/{DB_FNAME}"):
    shutil.copyfile(f"{study_dir}/dyn_thres_min_conf_opt_OG.db",
                    f"{study_dir}/{DB_FNAME}")

table_exists = does_table_exist('trials', f"{study_dir}/{DB_FNAME}")
study = load_or_create_study(is_new_study=not table_exists, study_dir=study_dir)

# The optimization itself is commented out; it was run once (on the
# validation-set predictions y_true_val/y_pred_val) to populate the database,
# and the best trial is read back below.
# dyn_thresh = 0.4
# min_conf = 0.01
# stats = calculate_froc(y_true=y_true_val,
#                        y_pred=y_pred_val,
#                        preprocess_func=dynamic_threshold,
#                        dynamic_threshold_factor=dyn_thresh,
#                        minimum_confidence=min_conf)
# sens, fpp = stats['sensitivity'], stats['fp_per_patient']
# p_auc = partial_auc(sens, fpp, low=0.1, high=2.5)
# print(f"the p_auc with the old settings is: {p_auc}")
#
# # Try to find the best values for the dynamic threshold and min_confidence.
# opt_func = lambda trial: p_auc_froc_obj(trial, y_true_val, y_pred_val)
# study.optimize(opt_func, n_trials=num_trials)

dyn_thresh = study.best_trial.params['dyn_thresh']
min_conf = study.best_trial.params['min_conf']
print(f"done. best dyn_thresh: {dyn_thresh}. Best min_conf: {min_conf}")
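# Optional sanity check (a sketch, not part of the original pipeline): list a
# few trials from the loaded study to confirm the database really contains
# optimization results. Only standard Optuna attributes are used here.
for t in study.trials[:5]:
    print(f"trial {t.number}: value={t.value}, params={t.params}")
print(f"best trial: #{study.best_trial.number}, pAUC FROC: {study.best_trial.value}")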
########## dump dict to yaml of best froc curve #############
# TODO: move this block into a function; it is the same load/predict pipeline
# used for the validation split above, now applied to the test split.
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['test_set0']
print("test_set0 indices:", TEST_INDEX[:5])

############ load data and preprocess ############
print(">>>>> read images <<<<<<<<<<")
image_paths = {}
for s in SERIES:
    with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
        image_paths[s] = [l.strip() for l in f.readlines()]
with open(path.join(DATA_DIR, "seg.txt"), 'r') as f:
    seg_paths = [l.strip() for l in f.readlines()]
num_images = len(seg_paths)

images = []
images_list = []
segmentations = []

# Read and preprocess each of the paths for each series, and the segmentations.
for img_idx in tqdm(range(len(TEST_INDEX))):  # slice with [:40] for fewer images
    img_s = {s: sitk.ReadImage(image_paths[s][TEST_INDEX[img_idx]], sitk.sitkFloat32)
             for s in SERIES}
    seg_s = sitk.ReadImage(seg_paths[TEST_INDEX[img_idx]], sitk.sitkFloat32)
    img_n, seg_n = preprocess(img_s, seg_s,
                              shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
    for seq in img_n:
        images.append(img_n[seq])
    images_list.append(images)
    images = []
    segmentations.append(seg_n)

# Reorder (batch, series, width, height, depth) to the model's expected
# (batch, width, height, depth, channels).
images_list = np.transpose(images_list, (0, 2, 3, 4, 1))
print("shape of segmentations is", np.shape(segmentations))
print('>>>>> size of images_list:', np.shape(images_list),
      f'. expected: (num_subjects, 192, 192, 24, {len(SERIES)})')

########### load model ##################
print(' >>>>>>> LOAD MODEL <<<<<<<<<')
dependencies = {
    'dice_coef': dice_coef,
    'weighted_cross_entropy_fn': weighted_binary_cross_entropy
}
reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)
# reconstructed_model.summary(line_length=120)

# Make predictions on all TEST_INDEX.
print(' >>>>>>> START prediction <<<<<<<<<')
predictions_blur = reconstructed_model.predict(images_list, batch_size=1)

############# preprocess #################
# Preprocess predictions by removing the blur and making individual blobs.
print('>>>>>>>> START preprocess')
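# Shape sanity check (a sketch; the single trailing channel is an assumption
# about the model head, the other dimensions follow from INPUT_SHAPE above):
# predictions_blur is expected to be (num_subjects, 192, 192, 24, 1).
print("prediction shape:", np.shape(predictions_blur))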
# def move_dims(arr):
#     # UMCG numpy dimension convention: dims = (batch, width, height, depth)
#     # Joeran's numpy dimension convention: dims = (batch, depth, height, width)
#     arr = np.moveaxis(arr, 3, 1)
#     arr = np.moveaxis(arr, 3, 2)
#     return arr
#
# # Joeran has his numpy arrays ordered differently.
# predictions_blur = move_dims(np.squeeze(predictions_blur))
# segmentations = move_dims(np.squeeze(segmentations))
y_pred_val = np.squeeze(predictions_blur)
y_true_val = segmentations
########### end of (to-be) function ############

# Compute the FROC statistics on the test set, using the best dynamic
# threshold and minimum confidence found by the Optuna study.
stats = calculate_froc(y_true=y_true_val,
                       y_pred=y_pred_val,
                       preprocess_func=dynamic_threshold,
                       dynamic_threshold_factor=dyn_thresh,
                       minimum_confidence=min_conf)

subject_idxs = list(range(len(y_true_val)))
metrics = {
    "num_patients": int(stats['num_patients']),
    "auroc": float(stats['patient_auc']),  # was int(), which truncated the AUC
    'tpr': convert_np_to_list(stats['roc_tpr']),
    'fpr': convert_np_to_list(stats['roc_fpr']),
    "roc_true": convert_np_to_list([stats['roc_patient_level_label'][s] for s in subject_idxs]),
    "roc_pred": convert_np_to_list([stats['roc_patient_level_conf'][s] for s in subject_idxs]),
    "num_lesions": int(stats['num_lesions']),
    "thresholds": convert_np_to_list(stats['thresholds']),
    "sensitivity": convert_np_to_list(stats['sensitivity']),
    "FP_per_case": convert_np_to_list(stats['fp_per_patient']),
    "precision": convert_np_to_list(stats['precision']),
    "recall": convert_np_to_list(stats['recall']),
    "AP": float(stats['average_precision']),  # was int(), which truncated the AP
}

dump_dict_to_yaml(metrics, YAML_DIR, "froc_metrics_optuna_test", verbose=True)
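# Optional round-trip check (a sketch; the exact output filename is an
# assumption based on the arguments passed to dump_dict_to_yaml above):
# reloaded = read_yaml_to_dict(f"{YAML_DIR}/froc_metrics_optuna_test.yml")
# assert reloaded["num_lesions"] == metrics["num_lesions"]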