add scripts
1  .gitignore  (vendored)
									
									
								
							| @@ -143,7 +143,6 @@ cython_debug/ | |||||||
| /old_code/ | /old_code/ | ||||||
| /data/ | /data/ | ||||||
| /job_scripts/ | /job_scripts/ | ||||||
| /scripts/ |  | ||||||
| /temp/ | /temp/ | ||||||
| /slurms/ | /slurms/ | ||||||
| *.out | *.out | ||||||
|   | |||||||
							
								
								
									
203  scripts/1.U-net_chris.py  (new executable file)
									
								
							| @@ -0,0 +1,203 @@ | |||||||
|  | import multiprocessing | ||||||
|  | from os import path | ||||||
|  | import argparse | ||||||
|  | import time | ||||||
|  | from datetime import datetime | ||||||
|  | import sys  | ||||||
|  | # sys.path.append('./../code') | ||||||
|  | # from utils_quintin import *  | ||||||
|  | # from sfransen.utils_quintin import * | ||||||
|  | # sys.path.append('./../code/DWI_exp')     | ||||||
|  | # from callbacks import IntermediateImages, dice_coef | ||||||
|  | # from callbacks import RocCallback | ||||||
|  | from sfransen.DWI_exp import IntermediateImages, dice_coef | ||||||
|  | from sfransen.DWI_exp.preprocessing_function import preprocess | ||||||
|  | import yaml | ||||||
|  | import numpy as np | ||||||
|  | import SimpleITK as sitk  # used below for sitk.ReadImage | ||||||
|  | from tqdm import tqdm | ||||||
|  | import tensorflow as tf | ||||||
|  | tf.compat.v1.disable_eager_execution() | ||||||
|  |  | ||||||
|  | from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger | ||||||
|  | from tensorflow.keras.optimizers import Adam | ||||||
|  | from sklearn.model_selection import KFold | ||||||
|  |  | ||||||
|  | from sfransen.DWI_exp.helpers import * | ||||||
|  | from sfransen.DWI_exp.batchgenerator import BatchGenerator | ||||||
|  |  | ||||||
|  | from sfransen.DWI_exp.unet import build_dual_attention_unet | ||||||
|  | from focal_loss import BinaryFocalLoss | ||||||
|  |  | ||||||
|  | parser = argparse.ArgumentParser( | ||||||
|  |     description='Train a U-Net model for segmentation/detection tasks ' +  | ||||||
|  |                 'using cross-validation.') | ||||||
|  | parser.add_argument('--series', '-s',  | ||||||
|  |     metavar='[series_name]', required=True, nargs='+', | ||||||
|  |     help='List of series to include, must correspond with ' + | ||||||
|  |         "path files in ./data/") | ||||||
|  | parser.add_argument('-experiment', | ||||||
|  |     help='Experiment title used to name the output folder, e.g.: test_b50_b400_b800'  | ||||||
|  | ) | ||||||
|  | args = parser.parse_args() | ||||||
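|  |  | ||||||
|  | # Example invocation (series names and experiment title below are hypothetical): | ||||||
|  | #   python 1.U-net_chris.py --series t2 b50 b400 b800 -experiment test_b50_b400_b800 | ||||||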
|  |  | ||||||
|  | # Determine the number of input series | ||||||
|  | num_series = len(args.series) | ||||||
|  |  | ||||||
|  | # Identify this job by the series included in the training | ||||||
|  | # Output folder will have this name, e.g.: b0_b50_b100 | ||||||
|  | # JOB_NAME = '_'.join(args.series) | ||||||
|  | JOB_NAME = args.experiment | ||||||
|  | DATA_DIR = "./../data/Nijmegen paths/" | ||||||
|  | # DATA_DIR = "./../data/new/" | ||||||
|  | # PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{args.experiment}" | ||||||
|  | PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{JOB_NAME}" | ||||||
|  | # 0.5 x 0.5 mm in-plane resolution, 3 mm slice thickness | ||||||
|  | TARGET_SPACING = (0.5, 0.5, 3)  | ||||||
|  | INPUT_SHAPE = (192, 192, 24, num_series) #(64, 64, 20, num_series) | ||||||
|  | IMAGE_SHAPE = INPUT_SHAPE[:3] | ||||||
|  | OUTPUT_SHAPE = (192, 192, 24, 1) # One output channel (segmentation) | ||||||
|  |  | ||||||
|  | # Hyperparameters | ||||||
|  | FOCAL_LOSS_GAMMA = 2 | ||||||
|  | INITIAL_LEARNING_RATE = 1e-4 | ||||||
|  | MAX_EPOCHS = 600 | ||||||
|  | EARLY_STOPPING = 50 | ||||||
|  | # increase batch size | ||||||
|  | BATCH_SIZE = 12 | ||||||
|  | MODEL_SELECTION_METRIC = 'val_loss' | ||||||
|  | MODEL_SELECTION_DIRECTION = "min" # Change to 'max' if higher value is better | ||||||
|  | EARLY_STOPPING_METRIC = 'val_loss' | ||||||
|  | EARLY_STOPPING_DIRECTION = "min" # Change to 'max' if higher value is better | ||||||
|  |  | ||||||
|  | # Training configuration | ||||||
|  | # add metric ROC_AUC | ||||||
|  | TRAINING_METRICS = ["binary_crossentropy", "binary_accuracy", dice_coef] | ||||||
|  | loss = BinaryFocalLoss(gamma=FOCAL_LOSS_GAMMA) | ||||||
|  | optimizer = Adam(learning_rate=INITIAL_LEARNING_RATE) | ||||||
|  |  | ||||||
|  | # Create folder structure in the output directory | ||||||
|  | if path.exists(PROJECT_DIR): | ||||||
|  |     PROJECT_DIR = PROJECT_DIR + '_(2)' | ||||||
|  | prepare_project_dir(PROJECT_DIR) | ||||||
|  |  | ||||||
|  | #save params to yaml | ||||||
|  | params = { | ||||||
|  |     "focal_loss_gamma": FOCAL_LOSS_GAMMA, | ||||||
|  |     "initial_learning_rate": INITIAL_LEARNING_RATE, | ||||||
|  |     "max_epochs": MAX_EPOCHS, | ||||||
|  |     "MODEL_SELECTION_METRIC": MODEL_SELECTION_METRIC, | ||||||
|  |     "EARLY_STOPPING_METRIC": EARLY_STOPPING_METRIC, | ||||||
|  |     "train_output_dir": PROJECT_DIR, | ||||||
|  |     "batch_size": BATCH_SIZE, | ||||||
|  |     "optimizer": optimizer, | ||||||
|  |     "loss": loss, | ||||||
|  |     "datetime": print(datetime.now().strftime("%Y-%m-%d"))} | ||||||
|  | dump_dict_to_yaml(params, f"{PROJECT_DIR}", filename=f"params") | ||||||
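|  | # A minimal note, assuming dump_dict_to_yaml serialises plain Python values: the | ||||||
|  | # Keras optimizer/loss objects stored above may not round-trip through YAML cleanly; | ||||||
|  | # one option is to store their class names instead, e.g.: | ||||||
|  | #   params["optimizer"] = type(optimizer).__name__  # "Adam" | ||||||
|  | #   params["loss"] = type(loss).__name__            # "BinaryFocalLoss" | ||||||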
|  |  | ||||||
|  | # Build the U-Net model | ||||||
|  | detection_model = build_dual_attention_unet(INPUT_SHAPE) | ||||||
|  | detection_model.summary(line_length=120) | ||||||
|  |  | ||||||
|  | # Load all numpy images into RAM | ||||||
|  | images, image_paths = {s: [] for s in args.series}, {} | ||||||
|  | segmentations = [] | ||||||
|  | print_(f"> Loading images into RAM...") | ||||||
|  |  | ||||||
|  | # Read the image paths from the data directory. | ||||||
|  | # Text files are expected to be named "[series_name].txt" | ||||||
|  | for s in args.series: | ||||||
|  |     with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f: | ||||||
|  |         image_paths[s] = [l.strip() for l in f.readlines()] | ||||||
|  | with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f: | ||||||
|  |     seg_paths = [l.strip() for l in f.readlines()] | ||||||
|  | num_images = len(seg_paths) | ||||||
|  |  | ||||||
|  | # Read and preprocess each of the paths for each series, and the segmentations. | ||||||
|  | for img_idx in tqdm(range(num_images)):  # slice, e.g. range(num_images)[:40], to load fewer images | ||||||
|  |     img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)  | ||||||
|  |         for s in args.series} | ||||||
|  |     seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32) | ||||||
|  |     img_n, seg_n = preprocess(img_s, seg_s,  | ||||||
|  |         shape=IMAGE_SHAPE, spacing=TARGET_SPACING) | ||||||
|  |     for seq in img_n: | ||||||
|  |         images[seq].append(img_n[seq]) | ||||||
|  |     segmentations.append(seg_n) | ||||||
|  |  | ||||||
|  | # Split train and validation. | ||||||
|  | # The split indexes are read from a precomputed YAML file (see | ||||||
|  | # 3.make_train_val_test_indexes.py); the KFold code below is kept for reference. | ||||||
|  | # kfold = KFold(10, shuffle=True, random_state=123) | ||||||
|  | # train_idxs, valid_idxs = list(kfold.split(segmentations))[0] | ||||||
|  | # train_idxs = list(train_idxs) | ||||||
|  | # valid_idxs = list(valid_idxs) | ||||||
|  |  | ||||||
|  | yml_paths = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') | ||||||
|  | train_idxs = yml_paths['train_set0'] | ||||||
|  | valid_idxs = yml_paths['val_set0'] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | detection_model.compile( | ||||||
|  |     optimizer=optimizer,  | ||||||
|  |     loss=loss, | ||||||
|  |     metrics=TRAINING_METRICS | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  | train_generator = BatchGenerator(images, segmentations, | ||||||
|  |     sequences=args.series, | ||||||
|  |     shape=IMAGE_SHAPE, | ||||||
|  |     indexes=train_idxs,  | ||||||
|  |     batch_size=BATCH_SIZE, | ||||||
|  |     shuffle=True,  | ||||||
|  |     augmentation_function=augment | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | valid_generator = get_generator(images, segmentations, | ||||||
|  |     sequences=args.series, | ||||||
|  |     shape=IMAGE_SHAPE, | ||||||
|  |     indexes=valid_idxs,  | ||||||
|  |     batch_size=None, | ||||||
|  |     shuffle=False,  | ||||||
|  |     augmentation=None | ||||||
|  | ) | ||||||
|  | valid_data = next(valid_generator) | ||||||
|  | print_(f"The shape of valid_data input = {np.shape(valid_data[0])}") | ||||||
|  | print_(f"The shape of valid_data label = {np.shape(valid_data[1])}") | ||||||
|  |  | ||||||
|  | callbacks = [ | ||||||
|  |     EarlyStopping( | ||||||
|  |         monitor=EARLY_STOPPING_METRIC, | ||||||
|  |         mode=EARLY_STOPPING_DIRECTION, | ||||||
|  |         patience=EARLY_STOPPING, | ||||||
|  |         verbose=1), | ||||||
|  |     ModelCheckpoint( | ||||||
|  |         filepath=path.join(PROJECT_DIR, "models", JOB_NAME + ".h5"),  | ||||||
|  |         monitor=MODEL_SELECTION_METRIC, | ||||||
|  |         mode=MODEL_SELECTION_DIRECTION, | ||||||
|  |         verbose=1, | ||||||
|  |         save_best_only=True), | ||||||
|  |     # ModelCheckpoint( | ||||||
|  |     #     filepath=path.join(PROJECT_DIR, "models_dice", JOB_NAME + ".h5"),  | ||||||
|  |     #     monitor='val_dice_coef', | ||||||
|  |     #     mode='max', | ||||||
|  |     #     verbose=0, | ||||||
|  |     #     save_best_only=True), | ||||||
|  |     CSVLogger( | ||||||
|  |         filename=path.join(PROJECT_DIR, "logs", f"{JOB_NAME}.csv")), | ||||||
|  |     IntermediateImages( | ||||||
|  |         validation_set=valid_data, | ||||||
|  |         sequences=args.series, | ||||||
|  |         prefix=path.join(PROJECT_DIR, "output", JOB_NAME), | ||||||
|  |         num_images=25) | ||||||
|  |     # RocCallback( | ||||||
|  |     #      validation_set=valid_data, | ||||||
|  |     #      num_images=25) | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | detection_model.fit(train_generator, | ||||||
|  |     validation_data    = valid_data, | ||||||
|  |     steps_per_epoch    = len(train_idxs) // BATCH_SIZE,  | ||||||
|  |     epochs             = MAX_EPOCHS, | ||||||
|  |     callbacks          = callbacks, | ||||||
|  |     verbose            = 1, | ||||||
|  |     ) | ||||||
							
								
								
									
107  scripts/3.make_train_val_test_indexes.py  (new executable file)
									
								
							| @@ -0,0 +1,107 @@ | |||||||
|  | import argparse | ||||||
|  | import random | ||||||
|  | import sys | ||||||
|  | sys.path.append('./../code')   | ||||||
|  | from utils_quintin import list_from_file, dump_dict_to_yaml, print_p | ||||||
|  |  | ||||||
|  | ################################  README  ###################################### | ||||||
|  | # NEW - This script creates indexes for the training, validation and test | ||||||
|  | # sets. It writes a yaml file with one set of train/val/test indexes per fold, | ||||||
|  | # so that they can be used for k-fold cross validation when needed. The filename | ||||||
|  | # should contain an integer, which is the seed used to generate the indexes. The | ||||||
|  | # filename also indicates the type of split used. | ||||||
|  | # Output structure: | ||||||
|  | #   trainSet0: [indexes] | ||||||
|  | #   valSet0:   [indexes] | ||||||
|  | #   testSet0:  [indexes] | ||||||
|  | #   trainSet1: [indexes] | ||||||
|  | #   valSet1:   [indexes] | ||||||
|  | #   testSet1:  [indexes] | ||||||
|  | # etc... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ################################ PARSER ######################################## | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def parse_input_args(): | ||||||
|  |     parser = argparse.ArgumentParser(description='Parse arguments for splitting training, validation and the test set.') | ||||||
|  |  | ||||||
|  |     parser.add_argument('-n', | ||||||
|  |                         '--num_folds', | ||||||
|  |                         type=int, | ||||||
|  |                         default=1, | ||||||
|  |                         help='The number of folds in total. This amount of index sets will be created.') | ||||||
|  |  | ||||||
|  |     parser.add_argument('-p', | ||||||
|  |                         '--path_to_paths_file', | ||||||
|  |                         type=str, | ||||||
|  |                         default="./../data/Nijmegen paths/seg.txt", | ||||||
|  |                         help='Path to the .txt file containing paths to the nifti files.') | ||||||
|  |  | ||||||
|  |     parser.add_argument('-s', | ||||||
|  |                         '--split', | ||||||
|  |                         type=str, | ||||||
|  |                         default="80/10/10", | ||||||
|  |                         help='Train/validation/test split in percentages.') | ||||||
|  |  | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     # split the given split string. | ||||||
|  |     args.p_train = int(args.split.split('/')[0]) | ||||||
|  |     args.p_val = int(args.split.split('/')[1]) | ||||||
|  |     args.p_test = int(args.split.split('/')[2]) | ||||||
|  |  | ||||||
|  |     assert args.p_train + args.p_val + args.p_test == 100, "The train, val and test split must sum to 100%." | ||||||
|  |      | ||||||
|  |     return args | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ################################################################################ | ||||||
|  | SEED = 3478 | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     print_p('\n\nMaking train - validation - test indexes.') | ||||||
|  |  | ||||||
|  |     # Parse some arguments | ||||||
|  |     args = parse_input_args() | ||||||
|  |     print_p(args) | ||||||
|  |  | ||||||
|  |     # Read the number of observations/subjects in the data. | ||||||
|  |     t2_paths = list_from_file(args.path_to_paths_file) | ||||||
|  |     num_obs = len(t2_paths) | ||||||
|  |     print_p(f"Number of observations in {args.path_to_paths_file}: {len(t2_paths)}") | ||||||
|  |  | ||||||
|  |     # Create cutoff points for training, validation and test set. | ||||||
|  |     train_cutoff = int(args.p_train/100 * num_obs) | ||||||
|  |     val_cutoff   = int(args.p_val/100 * num_obs) + train_cutoff | ||||||
|  |     test_cutoff  = int(args.p_test/100 * num_obs) + val_cutoff | ||||||
|  |     print(f"\ncutoffs: {train_cutoff}, {val_cutoff}, {test_cutoff}") | ||||||
|  |  | ||||||
|  |     # Create dict that will hold all the data | ||||||
|  |     data_dict = {} | ||||||
|  |     data_dict["init_seed"] = SEED | ||||||
|  |     data_dict["split"] = args.split | ||||||
|  |          | ||||||
|  |     # Loop over the number of folds; that many index sets will be written to the yaml file. | ||||||
|  |     for set_idx in range(args.num_folds): | ||||||
|  |  | ||||||
|  |         # Set new seed first | ||||||
|  |         random.seed(SEED + set_idx) | ||||||
|  |  | ||||||
|  |         # shuffle the indexes  | ||||||
|  |         indexes = list(range(num_obs)) | ||||||
|  |         random.shuffle(indexes) | ||||||
|  |  | ||||||
|  |         train_idxs = indexes[:train_cutoff] | ||||||
|  |         val_idxs   = indexes[train_cutoff:val_cutoff] | ||||||
|  |         test_idxs  = indexes[val_cutoff:test_cutoff] | ||||||
|  |          | ||||||
|  |         data_dict[f"train_set{set_idx}"] = train_idxs | ||||||
|  |         data_dict[f"val_set{set_idx}"]   = val_idxs | ||||||
|  |         data_dict[f"test_set{set_idx}"]  = test_idxs | ||||||
|  |          | ||||||
|  |     for key in data_dict: | ||||||
|  |         if type(data_dict[key]) == list: | ||||||
|  |             print(f"{key}: {len(data_dict[key])}") | ||||||
|  |  | ||||||
|  |     dump_dict_to_yaml(data_dict, "./../data", filename=f"train_val_test_idxs", verbose=False) | ||||||
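|  |  | ||||||
|  | # Example invocation (hypothetical): | ||||||
|  | #   python 3.make_train_val_test_indexes.py -n 10 -s 80/10/10 | ||||||
|  | # writes ./../data/train_val_test_idxs.yml with keys train_set0..train_set9, | ||||||
|  | # val_set0..val_set9 and test_set0..test_set9. | ||||||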
							
								
								
									
217  scripts/4.frocs.py  (new executable file)
									
								
							| @@ -0,0 +1,217 @@ | |||||||
|  | import sys | ||||||
|  | from os import path | ||||||
|  | import SimpleITK as sitk | ||||||
|  | import tensorflow as tf | ||||||
|  | from tensorflow.keras.models import load_model | ||||||
|  | from focal_loss import BinaryFocalLoss | ||||||
|  | import json | ||||||
|  | import argparse | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  | import numpy as np | ||||||
|  | import multiprocessing | ||||||
|  | from functools import partial | ||||||
|  |  | ||||||
|  | sys.path.append('./../code') | ||||||
|  | from utils_quintin import *  | ||||||
|  |  | ||||||
|  | sys.path.append('./../code/DWI_exp')     | ||||||
|  | from helpers import * | ||||||
|  | from preprocessing_function import preprocess   | ||||||
|  | from callbacks import dice_coef | ||||||
|  |  | ||||||
|  | sys.path.append('./../code/FROC') | ||||||
|  | from blob_preprocess import * | ||||||
|  | from cal_froc_from_np import * | ||||||
|  |  | ||||||
|  |  | ||||||
|  | parser = argparse.ArgumentParser( | ||||||
|  |     description='Compute FROC/ROC metrics for a trained U-Net ' +  | ||||||
|  |                 'segmentation/detection model on the validation set.') | ||||||
|  | parser.add_argument('--series', '-s',  | ||||||
|  |     metavar='[series_name]', required=True, nargs='+', | ||||||
|  |     help='List of series to include, must correspond with ' + | ||||||
|  |         "path files in ./data/") | ||||||
|  | args = parser.parse_args() | ||||||
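|  |  | ||||||
|  | # Example invocation (series names are hypothetical): | ||||||
|  | #   python 4.frocs.py -s t2 b50 b400 b800 b1400 adc | ||||||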
|  |  | ||||||
|  | ######## parsed inputs ############# | ||||||
|  | # SERIES = ['b50', 'b400', 'b800'] #can be parsed | ||||||
|  | SERIES = args.series | ||||||
|  | series_ = '_'.join(args.series) | ||||||
|  | # Import model | ||||||
|  | # MODEL_PATH = f'./../train_output/train_10h_{series_}/models/train_10h_{series_}.h5' | ||||||
|  | # YAML_DIR = f'./../train_output/train_10h_{series_}' | ||||||
|  | MODEL_PATH = f'./../train_output/train_n0.001_{series_}/models/train_n0.001_{series_}.h5' | ||||||
|  | print(MODEL_PATH) | ||||||
|  | YAML_DIR = f'./../train_output/train_n0.001_{series_}' | ||||||
|  | ################ constants ############ | ||||||
|  | DATA_DIR = "./../data/Nijmegen paths/" | ||||||
|  | TARGET_SPACING = (0.5, 0.5, 3)  | ||||||
|  | INPUT_SHAPE = (192, 192, 24, len(SERIES)) | ||||||
|  | IMAGE_SHAPE = INPUT_SHAPE[:3] | ||||||
|  |  | ||||||
|  | # import val_indx | ||||||
|  | DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') | ||||||
|  | TEST_INDEX = DATA_SPLIT_INDEX['val_set0'] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ########## load images ############## | ||||||
|  | images, image_paths = {s: [] for s in SERIES}, {} | ||||||
|  | segmentations = [] | ||||||
|  | print_(f"> Loading images into RAM...") | ||||||
|  |  | ||||||
|  | for s in SERIES: | ||||||
|  |     with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f: | ||||||
|  |         image_paths[s] = [l.strip() for l in f.readlines()] | ||||||
|  | with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f: | ||||||
|  |     seg_paths = [l.strip() for l in f.readlines()] | ||||||
|  | num_images = len(seg_paths) | ||||||
|  |  | ||||||
|  | # Read and preprocess each of the paths for each SERIES, and the segmentations. | ||||||
|  |  | ||||||
|  | from typing import List | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def load_images( | ||||||
|  |     image_paths: str, | ||||||
|  |     seq: str, | ||||||
|  |     target_shape: List[int], | ||||||
|  |     target_space: List[float]): | ||||||
|  |  | ||||||
|  |     img_s = sitk.ReadImage(image_paths, sitk.sitkFloat32) | ||||||
|  |  | ||||||
|  |     #resample | ||||||
|  |     mri_tra_s = resample(img_s,  | ||||||
|  |                         min_shape=target_shape,  | ||||||
|  |                         method=sitk.sitkNearestNeighbor,  | ||||||
|  |                         new_spacing=target_space) | ||||||
|  |  | ||||||
|  |     #center crop | ||||||
|  |     mri_tra_s = center_crop(mri_tra_s, shape=target_shape) | ||||||
|  |     #normalize | ||||||
|  |     if seq != 'seg': | ||||||
|  |         filter = sitk.NormalizeImageFilter() | ||||||
|  |         mri_tra_s = filter.Execute(mri_tra_s) | ||||||
|  |     else: | ||||||
|  |         filter = sitk.BinaryThresholdImageFilter() | ||||||
|  |         filter.SetLowerThreshold(1.0) | ||||||
|  |         mri_tra_s = filter.Execute(mri_tra_s) | ||||||
|  |      | ||||||
|  |     return sitk.GetArrayFromImage(mri_tra_s).T | ||||||
|  |  | ||||||
|  | N_CPUS = 12 | ||||||
|  | pool = multiprocessing.Pool(processes=N_CPUS) | ||||||
|  | partial_f = partial(load_images, | ||||||
|  |                     seq = 'images', | ||||||
|  |                     target_shape=IMAGE_SHAPE, | ||||||
|  |                     target_space = TARGET_SPACING) | ||||||
|  |  | ||||||
|  | images_2 = [] | ||||||
|  | for s in SERIES: | ||||||
|  |     image_paths_seq = image_paths[s] | ||||||
|  |     image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX] | ||||||
|  |     data_list = pool.map(partial_f,image_paths_index) | ||||||
|  |     data = np.stack(data_list, axis=0) | ||||||
|  |     images_2.append(data) | ||||||
|  |     # print(s) | ||||||
|  |     # print(np.shape(data)) | ||||||
|  |     print(np.shape(images_2)) | ||||||
|  |  | ||||||
|  | partial_f = partial(load_images, | ||||||
|  |                     seq = 'seg', | ||||||
|  |                     target_shape=IMAGE_SHAPE, | ||||||
|  |                     target_space = TARGET_SPACING) | ||||||
|  | seg_paths_index = np.asarray(seg_paths)[TEST_INDEX] | ||||||
|  | data_list = pool.map(partial_f,seg_paths_index) | ||||||
|  | segmentations = np.stack(data_list, axis=0) | ||||||
|  |  | ||||||
|  | # print("segmentations pool",np.shape(segmentations_2)) | ||||||
|  |  | ||||||
|  | # for img_idx in TEST_INDEX: #for less images | ||||||
|  | #     img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)  | ||||||
|  | #         for s in SERIES} | ||||||
|  | #     seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32) | ||||||
|  | #     img_n, seg_n = preprocess(img_s, seg_s,  | ||||||
|  | #         shape=IMAGE_SHAPE, spacing=TARGET_SPACING) | ||||||
|  | #     for seq in img_n: | ||||||
|  | #         images[seq].append(img_n[seq]) | ||||||
|  | #     segmentations.append(seg_n) | ||||||
|  |  | ||||||
|  | # print("segmentations old",np.shape(segmentations)) | ||||||
|  |  | ||||||
|  | # # from dict to list | ||||||
|  | # # images_list = [img nmbr, [INPUT_SHAPE]] | ||||||
|  | # images_list = [images[s] for s in images.keys()] | ||||||
|  | # images_list = np.transpose(images_list, (1, 2, 3, 4, 0)) | ||||||
|  | images_list = np.transpose(images_2, (1, 2, 3, 4, 0)) | ||||||
|  |  | ||||||
|  | print("images size ",np.shape(images_list)) | ||||||
|  | print("size segmentation",np.shape(segmentations)) | ||||||
|  | # print("images size pool",np.shape(images_list_2)) | ||||||
|  |  | ||||||
|  | import os  | ||||||
|  | os.environ["CUDA_VISIBLE_DEVICES"] = "2" | ||||||
|  | ########### load module ################## | ||||||
|  | print(' >>>>>>> LOAD MODEL <<<<<<<<<') | ||||||
|  |  | ||||||
|  | dependencies = { | ||||||
|  |     'dice_coef': dice_coef | ||||||
|  | } | ||||||
|  | reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies) | ||||||
|  | # reconstructed_model.summary(line_length=120) | ||||||
|  |  | ||||||
|  | # make predictions on all val_indx | ||||||
|  | print(' >>>>>>> START prediction <<<<<<<<<') | ||||||
|  | predictions_blur = reconstructed_model.predict(images_list, batch_size=1) | ||||||
|  |  | ||||||
|  | # print("The shape of the predictions list is: ",np.shape(predictions_blur)) | ||||||
|  | # print(type(predictions)) | ||||||
|  | # np.save('predictions',predictions) | ||||||
|  |  | ||||||
|  | # preprocess predictions by removing the blur and making individual blobs  | ||||||
|  | print('>>>>>>>> START preprocess') | ||||||
|  |  | ||||||
|  | def move_dims(arr): | ||||||
|  |     # UMCG numpy dimensions convention: dims = (batch, width, height, depth) | ||||||
|  |     # Joeran numpy dimensions convention: dims = (batch, depth, height, width) | ||||||
|  |     arr = np.moveaxis(arr, 3, 1) # Joeran has his numpy arrays ordered differently. | ||||||
|  |     arr = np.moveaxis(arr, 3, 2) | ||||||
|  |     return arr | ||||||
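|  | # e.g. move_dims turns an array of shape (N, 192, 192, 24) into (N, 24, 192, 192) | ||||||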
|  |  | ||||||
|  | # Joeran has his numpy arrays ordered differently. | ||||||
|  | predictions_blur = move_dims(np.squeeze(predictions_blur)) | ||||||
|  | segmentations = move_dims(np.squeeze(segmentations)) | ||||||
|  | predictions = [preprocess_softmax(pred, threshold="dynamic")[0] for pred in predictions_blur] | ||||||
|  |  | ||||||
|  | # Remove outer edges  | ||||||
|  | zeros = np.zeros(np.shape(predictions)) | ||||||
|  | test = np.squeeze(predictions)[:,:,2:190,2:190] | ||||||
|  | zeros[:,:,2:190,2:190] = test | ||||||
|  | predictions = zeros | ||||||
|  |  | ||||||
|  | # perform Froc | ||||||
|  | metrics = evaluate(y_true=segmentations, y_pred=predictions) | ||||||
|  | dump_dict_to_yaml(metrics, YAML_DIR, "froc_metrics", verbose=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # save one example case (index 3): blurred prediction, thresholded prediction and segmentation | ||||||
|  | IMAGE_DIR = f'./../train_output/train_10h_{series_}' | ||||||
|  | img_s = sitk.GetImageFromArray(predictions_blur[3].squeeze()) | ||||||
|  | sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_blur_001.nii.gz") | ||||||
|  |  | ||||||
|  | img_s = sitk.GetImageFromArray(predictions[3].squeeze()) | ||||||
|  | sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_001.nii.gz") | ||||||
|  |  | ||||||
|  | img_s = sitk.GetImageFromArray(segmentations[3].squeeze()) | ||||||
|  | sitk.WriteImage(img_s, f"{IMAGE_DIR}/segmentations_001.nii.gz") | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # create plot  | ||||||
|  | # json_path = './../scripts/metrics.json' | ||||||
|  | # f = open(json_path) | ||||||
|  | # data = json.load(f) | ||||||
|  | # x = data['fpr'] | ||||||
|  | # y = data['tpr'] | ||||||
|  | # auroc = data['auroc'] | ||||||
|  | # plt.plot(x,y) | ||||||
							
								
								
									
68  scripts/5.Visualize_frocs.py  (new executable file)
									
								
							| @@ -0,0 +1,68 @@ | |||||||
|  | import sys | ||||||
|  | sys.path.append('./../code') | ||||||
|  | from utils_quintin import * | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  | import argparse | ||||||
|  |  | ||||||
|  | parser = argparse.ArgumentParser( | ||||||
|  |     description='Visualise froc results') | ||||||
|  | parser.add_argument('-saveas',  | ||||||
|  |     help='Suffix used in the saved figure filenames.') | ||||||
|  | parser.add_argument('-comparison',  | ||||||
|  |     help='If set, plot the experiments in pairs (matching colours, solid/dashed lines).') | ||||||
|  | parser.add_argument('--experiment', '-s',  | ||||||
|  |     metavar='[experiment_name]', required=True, nargs='+', | ||||||
|  |     help='List of experiment folders in ./../train_output/ to include.') | ||||||
|  | args = parser.parse_args() | ||||||
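|  |  | ||||||
|  | # Example invocation (experiment folder names are hypothetical): | ||||||
|  | #   python 5.Visualize_frocs.py -saveas b_values -s train_10h_b800 train_10h_b400_b800 | ||||||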
|  |  | ||||||
|  | if args.comparison: | ||||||
|  |     colors = ['r','r','b','b','g','g'] | ||||||
|  |     plot_type = ['-','--','-','--','-','--'] | ||||||
|  | else:  | ||||||
|  |     colors = ['r','b','g','k'] | ||||||
|  |     plot_type = ['-','-','-','-'] | ||||||
|  |  | ||||||
|  | experiments = args.experiment | ||||||
|  | print(experiments) | ||||||
|  | experiment_path = [] | ||||||
|  | experiment_metrics = {} | ||||||
|  | auroc = [] | ||||||
|  | for idx in range(len(args.experiment)): | ||||||
|  |     experiment_path = f'./../train_output/{experiments[idx]}/froc_metrics.yml' | ||||||
|  |     experiment_metrics = read_yaml_to_dict(experiment_path) | ||||||
|  |     auroc.append(round(experiment_metrics['auroc'],3)) | ||||||
|  |  | ||||||
|  |     plt.figure(1) | ||||||
|  |     plt.plot(experiment_metrics["FP_per_case"], experiment_metrics["sensitivity"],color=colors[idx],linestyle=plot_type[idx]) | ||||||
|  |  | ||||||
|  |     plt.figure(2) | ||||||
|  |     plt.plot(experiment_metrics["fpr"], experiment_metrics["tpr"],color=colors[idx],linestyle=plot_type[idx]) | ||||||
|  |  | ||||||
|  | print(auroc) | ||||||
|  | experiments = [exp.replace('train_10h_', '') for exp in experiments]  | ||||||
|  | experiments = [exp.replace('train_n0.001_', '') for exp in experiments]  | ||||||
|  | experiments = [exp.replace('_', ' ') for exp in experiments]  | ||||||
|  | # experiments = ['10% noise','1% noise','0.1% noise','0.05% noise'] | ||||||
|  |  | ||||||
|  | plt.figure(1) | ||||||
|  | plt.title('fROC curve') | ||||||
|  | plt.xlabel('False positive per case') | ||||||
|  | plt.ylabel('Sensitivity') | ||||||
|  | plt.legend(experiments,loc='lower right') | ||||||
|  | plt.xlim([0,3]) | ||||||
|  | plt.ylim([0,1]) | ||||||
|  | plt.yticks([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]) | ||||||
|  | plt.grid() | ||||||
|  | plt.savefig(f"./../train_output/fROC_{args.saveas}.png", dpi=300) | ||||||
|  |  | ||||||
|  | concat_func = lambda x,y: x + " (" + str(y) + ")" | ||||||
|  | experiments_auroc = list(map(concat_func,experiments,auroc)) # materialise the map object as a list | ||||||
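|  | # e.g. experiment 'b400 b800' with auroc 0.85 becomes the legend entry 'b400 b800 (0.85)' | ||||||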
|  |  | ||||||
|  | plt.figure(2) | ||||||
|  | plt.title('ROC curve') | ||||||
|  | plt.legend(experiments_auroc,loc='lower right') | ||||||
|  | plt.xlabel('False positive rate') | ||||||
|  | plt.ylabel('True positive rate') | ||||||
|  | plt.grid() | ||||||
|  | plt.savefig(f"./../train_output/ROC_{args.saveas}.png", dpi=300) | ||||||
							
								
								
									
108  scripts/6.saliency_map.py  (new executable file)
									
								
							| @@ -0,0 +1,108 @@ | |||||||
|  | import sys | ||||||
|  | from os import path | ||||||
|  | import SimpleITK as sitk | ||||||
|  | import tensorflow as tf | ||||||
|  | from tensorflow import keras | ||||||
|  | from tensorflow.keras.models import load_model | ||||||
|  | from focal_loss import BinaryFocalLoss | ||||||
|  | import json | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  | import numpy as np  | ||||||
|  |  | ||||||
|  | from sfransen.Saliency.base import * | ||||||
|  | from sfransen.Saliency.integrated_gradients import * | ||||||
|  | # from tensorflow.keras.vis.visualization import visualize_saliency | ||||||
|  |  | ||||||
|  | sys.path.append('./../code') | ||||||
|  | from utils_quintin import *  | ||||||
|  |  | ||||||
|  | sys.path.append('./../code/DWI_exp')  | ||||||
|  | # from preprocessing_function import preprocess   | ||||||
|  | from sfransen.DWI_exp import preprocess | ||||||
|  | print("done step 1") | ||||||
|  | from sfransen.DWI_exp.helpers import * | ||||||
|  | # from helpers import * | ||||||
|  | from callbacks import dice_coef | ||||||
|  |  | ||||||
|  | sys.path.append('./../code/FROC') | ||||||
|  | from blob_preprocess import * | ||||||
|  | from cal_froc_from_np import * | ||||||
|  |  | ||||||
|  | # quit()  # debug stop after the imports; uncomment to halt early | ||||||
|  | # train_10h_t2_b50_b400_b800_b1400_adc | ||||||
|  | SERIES = ['t2','b50','b400','b800','b1400','adc'] | ||||||
|  | MODEL_PATH = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/models/train_10h_t2_b50_b400_b800_b1400_adc.h5' | ||||||
|  | YAML_DIR = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc' | ||||||
|  | ################ constants ############ | ||||||
|  | DATA_DIR = "./../data/Nijmegen paths/" | ||||||
|  | TARGET_SPACING = (0.5, 0.5, 3)  | ||||||
|  | INPUT_SHAPE = (192, 192, 24, len(SERIES)) | ||||||
|  | IMAGE_SHAPE = INPUT_SHAPE[:3] | ||||||
|  |  | ||||||
|  | # import val_indx | ||||||
|  | # DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') | ||||||
|  | # TEST_INDEX = DATA_SPLIT_INDEX['val_set0'] | ||||||
|  |  | ||||||
|  | experiment_path = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/froc_metrics.yml' | ||||||
|  | experiment_metrics = read_yaml_to_dict(experiment_path) | ||||||
|  | DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') | ||||||
|  | TEST_INDEX = DATA_SPLIT_INDEX['val_set0'] | ||||||
|  |  | ||||||
|  | top_10_idx = np.argsort(experiment_metrics['roc_pred'])[-10:] | ||||||
|  | TEST_INDEX = [TEST_INDEX[i] for i in top_10_idx] | ||||||
|  |  | ||||||
|  | ########## load images ############## | ||||||
|  | images, image_paths = {s: [] for s in SERIES}, {} | ||||||
|  | segmentations = [] | ||||||
|  | print_(f"> Loading images into RAM...") | ||||||
|  |  | ||||||
|  | for s in SERIES: | ||||||
|  |     with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f: | ||||||
|  |         image_paths[s] = [l.strip() for l in f.readlines()] | ||||||
|  | with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f: | ||||||
|  |     seg_paths = [l.strip() for l in f.readlines()] | ||||||
|  | num_images = len(seg_paths) | ||||||
|  |  | ||||||
|  | # Read and preprocess each of the paths for each SERIES, and the segmentations. | ||||||
|  | for img_idx in TEST_INDEX[:5]: # only the first 5 cases, for speed | ||||||
|  |     img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)  | ||||||
|  |         for s in SERIES} | ||||||
|  |     seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32) | ||||||
|  |     img_n, seg_n = preprocess(img_s, seg_s,  | ||||||
|  |         shape=IMAGE_SHAPE, spacing=TARGET_SPACING) | ||||||
|  |     for seq in img_n: | ||||||
|  |         images[seq].append(img_n[seq]) | ||||||
|  |     segmentations.append(seg_n) | ||||||
|  |  | ||||||
|  | images_list = [images[s] for s in images.keys()] | ||||||
|  | images_list = np.transpose(images_list, (1, 2, 3, 4, 0)) | ||||||
|  |   | ||||||
|  | ########### load module ################## | ||||||
|  | dependencies = { | ||||||
|  |     'dice_coef': dice_coef | ||||||
|  | } | ||||||
|  | reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies) | ||||||
|  |  | ||||||
|  | # reconstructed_model.layers[-1].activation = tf.keras.activations.linear | ||||||
|  |  | ||||||
|  | print('START prediction') | ||||||
|  |  | ||||||
|  | ig = IntegratedGradients(reconstructed_model) | ||||||
|  | saliency_map = [] | ||||||
|  | for img_idx in range(len(images_list)): | ||||||
|  |     # input_img = np.resize(images_list[img_idx],(1,48,48,8,8)) | ||||||
|  |     input_img = np.resize(images_list[img_idx],(1,192,192,24,len(SERIES))) | ||||||
|  |     saliency_map.append(ig.get_mask(input_img).numpy()) | ||||||
|  |     print("size saliency map is:",np.shape(saliency_map)) | ||||||
|  |  | ||||||
|  | np.save('saliency',saliency_map) | ||||||
|  |  | ||||||
|  | # Christian Roest, [11-3-2022 15:30] | ||||||
|  | # input_img has dimensions (1, 48, 48, 8, 8) | ||||||
|  |  | ||||||
|  | # reconstructed_model.summary(line_length=120) | ||||||
|  |  | ||||||
|  | # make predictions on all val_indx | ||||||
|  | print('START saliency') | ||||||
|  | # predictions_blur = reconstructed_model.predict(images_list, batch_size=1) | ||||||
|  |  | ||||||
							
								
								
									
90  scripts/7.Visualize_saliency.py  (new executable file)
									
								
							| @@ -0,0 +1,90 @@ | |||||||
|  | import numpy as np  | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  | import matplotlib.cm as cm | ||||||
|  |  | ||||||
|  | heatmap = np.load('saliency.npy') | ||||||
|  | print(np.shape(heatmap)) | ||||||
|  | heatmap = np.squeeze(heatmap) | ||||||
|  | print(np.shape(heatmap)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ### take average over 5 ######### | ||||||
|  | heatmap = np.mean(abs(heatmap),axis=0) | ||||||
|  | print(np.shape(heatmap)) | ||||||
|  |  | ||||||
|  | SERIES = ['t2','b50','b400','b800','b1400','adc'] | ||||||
|  | fig, axes = plt.subplots(1,6) | ||||||
|  | max_value = np.amax(heatmap) | ||||||
|  | min_value = np.amin(heatmap) | ||||||
|  | # use the vmin/vmax of the whole heatmap for scaling in imshow | ||||||
|  | # change cmap to grey  | ||||||
|  |  | ||||||
|  | im = axes[0].imshow(np.squeeze(heatmap[:,:,12,0]), vmin=min_value, vmax=max_value) | ||||||
|  | axes[1].imshow(np.squeeze(heatmap[:,:,12,1]), vmin=min_value, vmax=max_value) | ||||||
|  | axes[2].imshow(np.squeeze(heatmap[:,:,12,2]), vmin=min_value, vmax=max_value) | ||||||
|  | axes[3].imshow(np.squeeze(heatmap[:,:,12,3]), vmin=min_value, vmax=max_value) | ||||||
|  | axes[4].imshow(np.squeeze(heatmap[:,:,12,4]), vmin=min_value, vmax=max_value) | ||||||
|  | axes[5].imshow(np.squeeze(heatmap[:,:,12,5]), vmin=min_value, vmax=max_value) | ||||||
|  |  | ||||||
|  | axes[0].set_title("t2") | ||||||
|  | axes[1].set_title("b50") | ||||||
|  | axes[2].set_title("b400") | ||||||
|  | axes[3].set_title("b800") | ||||||
|  | axes[4].set_title("b1400") | ||||||
|  | axes[5].set_title("adc") | ||||||
|  |  | ||||||
|  | cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5, orientation='horizontal') | ||||||
|  | cbar.set_ticks([-0.1,0,0.1]) | ||||||
|  | cbar.set_ticklabels(['less importance', '0', 'important']) | ||||||
|  | fig.suptitle('Average saliency maps over the 5 highest predictions', fontsize=16) | ||||||
|  | plt.show() | ||||||
|  |  | ||||||
|  | quit() | ||||||
|  |  | ||||||
|  | #take one image out | ||||||
|  | heatmap = np.squeeze(heatmap[0]) | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  |  | ||||||
|  | # Fixing random state for reproducibility | ||||||
|  | np.random.seed(19680801) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class IndexTracker: | ||||||
|  |     def __init__(self, ax, X): | ||||||
|  |         self.ax = ax | ||||||
|  |         ax.set_title('use scroll wheel to navigate images') | ||||||
|  |  | ||||||
|  |         self.X = X | ||||||
|  |         rows, cols, self.slices = X.shape | ||||||
|  |         self.ind = self.slices//2 | ||||||
|  |  | ||||||
|  |         self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet') | ||||||
|  |         self.update() | ||||||
|  |  | ||||||
|  |     def on_scroll(self, event): | ||||||
|  |         print("%s %s" % (event.button, event.step)) | ||||||
|  |         if event.button == 'up': | ||||||
|  |             self.ind = (self.ind + 1) % self.slices | ||||||
|  |         else: | ||||||
|  |             self.ind = (self.ind - 1) % self.slices | ||||||
|  |         self.update() | ||||||
|  |  | ||||||
|  |     def update(self): | ||||||
|  |         self.im.set_data(self.X[:, :, self.ind]) | ||||||
|  |         self.ax.set_ylabel('slice %s' % self.ind) | ||||||
|  |         self.im.axes.figure.canvas.draw() | ||||||
|  |  | ||||||
|  | plt.figure(0) | ||||||
|  | fig, ax = plt.subplots(1, 1) | ||||||
|  | tracker = IndexTracker(ax, heatmap[:,:,:,5]) | ||||||
|  | fig.canvas.mpl_connect('scroll_event', tracker.on_scroll) | ||||||
|  | plt.show() | ||||||
|  |  | ||||||
|  | plt.figure(1) | ||||||
|  | fig, ax = plt.subplots(1, 1) | ||||||
|  | tracker = IndexTracker(ax, heatmap[:,:,:,3]) | ||||||
|  | fig.canvas.mpl_connect('scroll_event', tracker.on_scroll) | ||||||
|  | plt.show() | ||||||
							
								
								
									
59  scripts/8.Visualize_training.py  (new executable file)
									
								
							| @@ -0,0 +1,59 @@ | |||||||
|  | import matplotlib.pyplot as plt  | ||||||
|  | import pandas as pd | ||||||
|  | import glob | ||||||
|  | import argparse | ||||||
|  | import os | ||||||
|  |  | ||||||
|  | #create parser | ||||||
|  | def parse_input_args(): | ||||||
|  |     parser = argparse.ArgumentParser(description='Parse arguments for plotting the training curves of an experiment') | ||||||
|  |  | ||||||
|  |     parser.add_argument('train_out_dir', | ||||||
|  |                         type=str, | ||||||
|  |                         help='Path to the .csv file with training statistics inside the train_output directory of the desired experiment.') | ||||||
|  |  | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     return args | ||||||
|  |  | ||||||
|  | args = parse_input_args() | ||||||
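|  |  | ||||||
|  | # Example invocation (experiment name is hypothetical): | ||||||
|  | #   python 8.Visualize_training.py ./../train_output/test_b50_b400_b800/logs/test_b50_b400_b800.csv | ||||||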
|  |  | ||||||
|  | print(f"Plotting {args}") | ||||||
|  |  | ||||||
|  | # find csv file | ||||||
|  | # csv = glob.glob(f"train_output/{args.train_out_dir}/*.csv")[0] | ||||||
|  | folder_input = args.train_out_dir  | ||||||
|  |  | ||||||
|  | # load csv file | ||||||
|  | df = pd.read_csv(f'{folder_input}') | ||||||
|  |  | ||||||
|  | # read csv file | ||||||
|  | for metric in df: | ||||||
|  |     # if not metric == 'epoch': | ||||||
|  |     if metric == 'loss' or metric == 'val_loss':  | ||||||
|  |         plt.plot(df['epoch'], df[metric], label=metric) | ||||||
|  |         plt.ylim(0, 0.01) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | folder, csvfile = os.path.split(args.train_out_dir) | ||||||
|  | root, experiment = os.path.split(os.path.split(folder)[0]) | ||||||
|  |  | ||||||
|  | plt.title(experiment) | ||||||
|  | plt.xlabel('Epoch') | ||||||
|  | plt.ylabel('Loss') | ||||||
|  | plt.grid() | ||||||
|  | plt.legend() | ||||||
|  | plt.savefig(f"{folder}/{experiment}.png") | ||||||
|  | plt.clf() | ||||||
|  | plt.close() | ||||||
|  | print(folder+".png") | ||||||
|  | print(f"\nsaved figure to {folder}") | ||||||
|  |  | ||||||
|  | # load 'loss' from the yaml | ||||||
|  | # from utils > dict_to_yaml | ||||||
|  | # csv viewer extension  | ||||||
|  | # source .venv/bin/activate  | ||||||
|  | # loop over everything and if metric then 1-metric | ||||||
|  | # colour coding color=[] | ||||||
|  |  | ||||||
|  |  | ||||||