import multiprocessing
from os import path
import argparse
import time
from datetime import datetime
import sys

# sys.path.append('./../code')
# from utils_quintin import *
from sfransen.utils_quintin import *
# sys.path.append('./../code/DWI_exp')
# from callbacks import IntermediateImages, dice_coef
# from callbacks import RocCallback
from sfransen.utils_quintin import *
from sfransen.DWI_exp import IntermediateImages, dice_coef
from sfransen.DWI_exp.preprocessing_function import preprocess
from sfransen.DWI_exp.losses import weighted_binary_cross_entropy
import yaml
import numpy as np
from tqdm import tqdm
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import KFold
from sfransen.DWI_exp.helpers import *
from sfransen.DWI_exp.batchgenerator import BatchGenerator
from sfransen.DWI_exp.unet import build_dual_attention_unet
from focal_loss import BinaryFocalLoss

# NOTE(review): `sitk`, `print_`, `prepare_project_dir` and `dump_dict_to_yaml`
# are not imported explicitly above; they presumably come from the star
# imports of sfransen.utils_quintin / sfransen.DWI_exp.helpers -- confirm.

# Directory holding the "[series_name].txt" and "seg.txt" path lists.
DATA_DIR = "./../data/Nijmegen paths/"
# DATA_DIR = "./../data/new/"

parser = argparse.ArgumentParser(
    description='Train a U-Net model for segmentation/detection tasks '
    + 'using cross-validation.')
parser.add_argument('--series', '-s',
                    metavar='[series_name]', required=True, nargs='+',
                    help='List of series to include, must correspond with '
                    + f"path files ([series_name].txt) in {DATA_DIR}")
# The original single-dash flag is kept; `--experiment` is an alias for it.
parser.add_argument('-experiment', '--experiment',
                    default=None,
                    help='add experiment title to store the files correctly: '
                    'test_b50_b400_b800 (default: the series names joined '
                    'by "_")')
args = parser.parse_args()

# Determine the number of input series
num_series = len(args.series)

# Identify this job; output folders and files are named after it. Without an
# explicit experiment title, fall back to the series names, e.g. b0_b50_b100,
# instead of using None as the name.
JOB_NAME = args.experiment if args.experiment else '_'.join(args.series)

# PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{args.experiment}"
PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{JOB_NAME}"

# Target voxel spacing (x, y, z): 0.5 x 0.5 in-plane, 3 through-plane
# (presumably mm, as in the image headers -- confirm).
TARGET_SPACING = (0.5, 0.5, 3)
INPUT_SHAPE = (192, 192, 24, num_series)  # (64, 64, 20, num_series)
IMAGE_SHAPE = INPUT_SHAPE[:3]
OUTPUT_SHAPE = (192, 192, 24, 1)  # One output channel (segmentation)

# Hyperparameters
# NOTE: only used by the (currently disabled) BinaryFocalLoss below, but it is
# still recorded in params.yml.
FOCAL_LOSS_GAMMA = 2
INITIAL_LEARNING_RATE = 1e-4
MAX_EPOCHS = 1500
EARLY_STOPPING = 50  # patience, in epochs
# increase batch size
BATCH_SIZE = 12
MODEL_SELECTION_METRIC = 'val_loss'
MODEL_SELECTION_DIRECTION = "min"  # Change to 'max' if higher value is better
EARLY_STOPPING_METRIC = 'val_loss'
EARLY_STOPPING_DIRECTION = 'min'
# MODEL_SELECTION_METRIC = 'weighted_binary_cross_entropy'
# MODEL_SELECTION_DIRECTION = "min"  # Change to 'max' if higher value is better
# EARLY_STOPPING_METRIC = 'weighted_binary_cross_entropy'
# EARLY_STOPPING_DIRECTION = 'min'

# Training configuration
# add metric ROC_AUC
TRAINING_METRICS = ["binary_crossentropy", "binary_accuracy", dice_coef]
# loss = BinaryFocalLoss(gamma=FOCAL_LOSS_GAMMA)
# Class weights for the weighted BCE: class 1 is weighted 0.95, class 0 0.05.
weight_for_0 = 0.05
weight_for_1 = 0.95
loss = weighted_binary_cross_entropy({0: weight_for_0, 1: weight_for_1})
optimizer = Adam(learning_rate=INITIAL_LEARNING_RATE)

# Create the folder structure in the output directory. If a run with this name
# already exists, switch PROJECT_DIR to the first free "_(k)" variant (k >= 2)
# so that every later write (params, models, logs, images) goes to the new
# directory instead of overwriting the earlier run.
if path.exists(PROJECT_DIR):
    suffix = 2
    while path.exists(f"{PROJECT_DIR}_({suffix})"):
        suffix += 1
    PROJECT_DIR = f"{PROJECT_DIR}_({suffix})"
prepare_project_dir(PROJECT_DIR)

# Save the run parameters to <PROJECT_DIR>/params (YAML).
# NOTE(review): `optimizer` and `loss` are Python objects; how they serialize
# depends on dump_dict_to_yaml -- confirm the output is what you expect.
run_date = datetime.now().strftime("%Y-%m-%d")
print(run_date)
params = {
    "focal_loss_gamma": FOCAL_LOSS_GAMMA,
    "initial_learning_rate": INITIAL_LEARNING_RATE,
    "max_epochs": MAX_EPOCHS,
    "MODEL_SELECTION_METRIC": MODEL_SELECTION_METRIC,
    "EARLY_STOPPING_METRIC": EARLY_STOPPING_METRIC,
    "train_output_dir": PROJECT_DIR,
    "batch_size": BATCH_SIZE,
    "optimizer": optimizer,
    "loss": loss,
    # Store the date string itself (print() returns None).
    "datetime": run_date}
dump_dict_to_yaml(params, PROJECT_DIR, filename="params")

# Build the U-Net model
detection_model = build_dual_attention_unet(INPUT_SHAPE)
detection_model.summary(line_length=120)

# Load all numpy images into RAM
images, image_paths = {s: [] for s in args.series}, {}
segmentations = []

print_(f"> Loading images into RAM...")

# Read the image paths from the data directory.
# Text files are expected to have the name "[series_name].txt"; "seg.txt"
# lists the segmentations. Line i of every file is used for the same case i.
for s in args.series:
    with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
        image_paths[s] = [l.strip() for l in f.readlines()]
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f:
    seg_paths = [l.strip() for l in f.readlines()]
num_images = len(seg_paths)

# Read and preprocess each of the paths for each series, and the segmentations.
for img_idx in tqdm(range(num_images)):  # [:40]): # for less images
    img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)
             for s in args.series}
    seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
    img_n, seg_n = preprocess(img_s, seg_s,
                              shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
    for seq in img_n:
        images[seq].append(img_n[seq])
    segmentations.append(seg_n)

# Split train and validation.
# Previous approach (kept for reference): the first fold of a seeded 10-fold
# KFold, i.e. a 9:1 train:validation split.
# kfold = KFold(10, shuffle=True, random_state=123)
# train_idxs, valid_idxs = list(kfold.split(segmentations))[0]
# train_idxs = list(train_idxs)
# valid_idxs = list(valid_idxs)

# The split is read from a fixed index file (split 0) in DATA_DIR.
yml_paths = read_yaml_to_dict(path.join(DATA_DIR, 'train_val_test_idxs.yml'))
train_idxs = yml_paths['train_set0']
valid_idxs = yml_paths['val_set0']

detection_model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=TRAINING_METRICS
)

# Training batches: shuffled, with the `augment` function applied.
train_generator = BatchGenerator(images, segmentations,
    sequences=args.series,
    shape=IMAGE_SHAPE,
    indexes=train_idxs,
    batch_size=BATCH_SIZE,
    shuffle=True,
    augmentation_function=augment
)

# Validation data: unshuffled, no augmentation. A single next() is taken;
# with batch_size=None this presumably yields the whole validation set as one
# (inputs, labels) pair -- confirm against get_generator.
valid_generator = get_generator(images, segmentations,
    sequences=args.series,
    shape=IMAGE_SHAPE,
    indexes=valid_idxs,
    batch_size=None,
    shuffle=False,
    augmentation=None
)
valid_data = next(valid_generator)
print_(f"The shape of valid_data input = {np.shape(valid_data[0])}")
print_(f"The shape of valid_data label = {np.shape(valid_data[1])}")

callbacks = [
    # Best model by the model-selection metric (val_loss by default).
    ModelCheckpoint(
        filepath=path.join(PROJECT_DIR, "models", JOB_NAME + ".h5"),
        monitor=MODEL_SELECTION_METRIC,
        mode=MODEL_SELECTION_DIRECTION,
        verbose=2,
        save_best_only=True),
    # Separately keep the model with the highest validation Dice.
    ModelCheckpoint(
        filepath=path.join(PROJECT_DIR, "models", JOB_NAME + "_dice" + ".h5"),
        monitor='val_dice_coef',
        mode='max',
        verbose=2,
        save_best_only=True),
    CSVLogger(
        filename=path.join(PROJECT_DIR, "logs", f"{JOB_NAME}.csv")),
    EarlyStopping(
        monitor=EARLY_STOPPING_METRIC,
        mode=EARLY_STOPPING_DIRECTION,
        patience=EARLY_STOPPING,
        verbose=2),
    IntermediateImages(
        validation_set=valid_data,
        sequences=args.series,
        prefix=path.join(PROJECT_DIR, "output", JOB_NAME),
        num_images=25)
    # RocCallback(
    #     validation_set=valid_data,
    #     num_images=25)
]

# At least one step per epoch: a training set smaller than one batch would
# otherwise give steps_per_epoch == 0.
detection_model.fit(train_generator,
    validation_data=valid_data,
    steps_per_epoch=max(1, len(train_idxs) // BATCH_SIZE),
    epochs=MAX_EPOCHS,
    callbacks=callbacks,
    verbose=2,
)