221 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			221 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
| import multiprocessing
 | |
| from os import path
 | |
| import argparse
 | |
| import time
 | |
| from datetime import datetime
 | |
| import sys 
 | |
| # sys.path.append('./../code')
 | |
| # from utils_quintin import * 
 | |
| from sfransen.utils_quintin import *
 | |
| # sys.path.append('./../code/DWI_exp')    
 | |
| # from callbacks import IntermediateImages, dice_coef
 | |
| # from callbacks import RocCallback
 | |
| from sfransen.utils_quintin import *
 | |
| from sfransen.DWI_exp import IntermediateImages, dice_coef
 | |
| from sfransen.DWI_exp.preprocessing_function import preprocess
 | |
| from sfransen.DWI_exp.losses import weighted_binary_cross_entropy
 | |
| import yaml
 | |
| import numpy as np
 | |
| from tqdm import tqdm
 | |
| import tensorflow as tf
 | |
| tf.compat.v1.disable_eager_execution()
 | |
| 
 | |
| from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
 | |
| from tensorflow.keras.optimizers import Adam
 | |
| from sklearn.model_selection import KFold
 | |
| 
 | |
| from sfransen.DWI_exp.helpers import *
 | |
| from sfransen.DWI_exp.batchgenerator import BatchGenerator
 | |
| 
 | |
| from sfransen.DWI_exp.unet import build_dual_attention_unet
 | |
| from focal_loss import BinaryFocalLoss
 | |
| # from umcglib.utils import set_gpu
 | |
| # set_gpu(gpu_idx=1)
 | |
| 
 | |
# Command-line interface of the training script.
#   --series / -s : one or more series names; for each one the file
#                   "[series_name].txt" in the data directory is read.
#   -experiment   : experiment title; becomes the job name and output folder.
#   -fold         : fold identifier; selects train_val_test_idxs_[fold].yml.
parser = argparse.ArgumentParser(
    # Adjacent literals are joined by the parser; the separating space has to
    # be part of the text itself, otherwise the words run together.
    description='Train a U-Net model for segmentation/detection tasks '
                'using cross-validation.')
parser.add_argument('--series', '-s',
    metavar='[series_name]', required=True, nargs='+',
    help='List of series to include, must correspond with '
         'path files in ./data/')
# Both options below are required: without an experiment title the job name
# is None and building the checkpoint file names fails, and without a fold the
# index file name would contain "None". Failing here avoids loading all images
# into RAM before the script crashes.
parser.add_argument('-experiment', required=True,
    help='add experiment title to store the files correctly: test_b50_b400_b800')
parser.add_argument('-fold', required=True,
    help='import fold')
args = parser.parse_args()
 | |
| 
 | |
# Job identity and locations.
# The job is named after the experiment title given on the command line; the
# output folder below carries the same name.
JOB_NAME = args.experiment
DATA_DIR = "./../data/Nijmegen paths/"
PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{JOB_NAME}"

# One input channel per series passed on the command line.
num_series = len(args.series)

# Target spacing handed to the preprocessing step.
# NOTE(review): an earlier comment described "2 x 2mm2 in-plane resolution,
# 3.6mm slice thickness", which does not match these values - presumably the
# unit is mm per voxel along (x, y, z); confirm against preprocess().
TARGET_SPACING = (0.5, 0.5, 3)

# Spatial size of one volume, and the model input/output shapes built from it.
IMAGE_SHAPE = (192, 192, 24)
INPUT_SHAPE = IMAGE_SHAPE + (num_series,)
OUTPUT_SHAPE = IMAGE_SHAPE + (1,)  # One output channel (segmentation)
 | |
| 
 | |
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Gamma of the focal loss. The focal loss itself is currently not used as the
# training loss; the value is only written to the parameter file.
FOCAL_LOSS_GAMMA = 2
INITIAL_LEARNING_RATE = 0.0001
MAX_EPOCHS = 1500
# Number of epochs without improvement tolerated by the early-stopping callback.
EARLY_STOPPING = 50
# TODO(review): original note here read "increase batch size".
BATCH_SIZE = 12

# Metric/direction pairs watched by the checkpoint and early-stopping
# callbacks. Use "max" as direction when a higher value is better.
MODEL_SELECTION_METRIC = 'val_loss'
MODEL_SELECTION_DIRECTION = "min"
EARLY_STOPPING_METRIC = 'val_loss'
EARLY_STOPPING_DIRECTION = 'min'

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
# TODO: add metric ROC_AUC
TRAINING_METRICS = ["binary_crossentropy", "binary_accuracy", dice_coef]

# Weighted binary cross-entropy: label 0 gets weight 0.05, label 1 gets 0.95.
weight_for_0 = 0.05
weight_for_1 = 0.95
_class_weights = {0: weight_for_0, 1: weight_for_1}
loss = weighted_binary_cross_entropy(_class_weights)

optimizer = Adam(learning_rate=INITIAL_LEARNING_RATE)
 | |
| 
 | |
# Create the folder structure in the output directory.
# When a folder for this experiment already exists, switch to a "_(2)" sibling
# and keep using that path from here on, so that checkpoints, logs and the
# parameter file of the earlier run are not overwritten.
if path.exists(PROJECT_DIR):
    PROJECT_DIR = PROJECT_DIR + '_(2)'
prepare_project_dir(PROJECT_DIR)

# Save the run parameters to yaml.
# NOTE(review): "optimizer" and "loss" are live Python objects, not plain
# values; whether they serialize cleanly depends on dump_dict_to_yaml - confirm.
params = {
    "focal_loss_gamma": FOCAL_LOSS_GAMMA,
    "initial_learning_rate": INITIAL_LEARNING_RATE,
    "max_epochs": MAX_EPOCHS,
    "MODEL_SELECTION_METRIC": MODEL_SELECTION_METRIC,
    "EARLY_STOPPING_METRIC": EARLY_STOPPING_METRIC,
    "train_output_dir": PROJECT_DIR,
    "batch_size": BATCH_SIZE,
    "optimizer": optimizer,
    "loss": loss,
    # Store the date string itself (wrapping it in print() would store None).
    "datetime": datetime.now().strftime("%Y-%m-%d")}
dump_dict_to_yaml(params, f"{PROJECT_DIR}", filename=f"params")
 | |
| 
 | |
# Instantiate the dual-attention U-Net for the configured input shape and
# print its layer overview (120 characters wide).
detection_model = build_dual_attention_unet(INPUT_SHAPE)
detection_model.summary(line_length=120)
 | |
| 
 | |
# Containers for everything that is kept in RAM during training:
#   images        - per series name, a list of preprocessed volumes
#   image_paths   - per series name, the list of file paths to read
#   segmentations - preprocessed segmentations, same order as the images
segmentations = []
image_paths = {}
images = {series_name: [] for series_name in args.series}
print_(f"> Loading images into RAM...")

# Collect the file paths. Each series has a text file "[series_name].txt" in
# the data directory with one path per line; "seg.txt" lists the segmentations.
for series_name in args.series:
    with open(path.join(DATA_DIR, f"{series_name}.txt"), 'r') as path_file:
        image_paths[series_name] = [line.strip() for line in path_file]
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as path_file:
    seg_paths = [line.strip() for line in path_file]

# The number of segmentations determines how many cases are loaded.
num_images = len(seg_paths)

# Read every case (all series plus its segmentation) as float32, preprocess it
# to the target shape/spacing and store the results.
for img_idx in tqdm(range(num_images)):
    img_s = {}
    for series_name in args.series:
        img_s[series_name] = sitk.ReadImage(
            image_paths[series_name][img_idx], sitk.sitkFloat32)
    seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)

    img_n, seg_n = preprocess(
        img_s, seg_s, shape=IMAGE_SHAPE, spacing=TARGET_SPACING)

    for seq in img_n:
        images[seq].append(img_n[seq])
    segmentations.append(seg_n)
 | |
| 
 | |
# Split train and validation.
# The case indexes of both sets are read from a per-fold yaml file selected by
# the -fold argument; the keys 'train_set0' and 'val_set0' hold the two lists.
_split_file = f'./../data/Nijmegen paths/train_val_test_idxs_{args.fold}.yml'
yml_paths = read_yaml_to_dict(_split_file)
print('test, train paths',yml_paths)
train_idxs, valid_idxs = yml_paths['train_set0'], yml_paths['val_set0']
 | |
| 
 | |
| 
 | |
# Attach the loss, optimizer and monitored metrics to the model.
detection_model.compile(loss=loss, optimizer=optimizer, metrics=TRAINING_METRICS)
 | |
| 
 | |
# Arguments shared by the training and validation data sources.
_shared_generator_args = dict(sequences=args.series, shape=IMAGE_SHAPE)

# Training data: shuffled, augmented batches drawn from the training indexes.
train_generator = BatchGenerator(
    images, segmentations,
    indexes=train_idxs,
    batch_size=BATCH_SIZE,
    shuffle=True,
    augmentation_function=augment,
    **_shared_generator_args,
)

# Validation data: no shuffling and no augmentation. Only the first item the
# generator yields is used, as a fixed validation set.
# NOTE(review): batch_size=None presumably makes get_generator yield all
# validation cases at once - confirm in helpers.
valid_generator = get_generator(
    images, segmentations,
    indexes=valid_idxs,
    batch_size=None,
    shuffle=False,
    augmentation=None,
    **_shared_generator_args,
)
valid_data = next(valid_generator)

print_(f"The shape of valid_data input = {np.shape(valid_data[0])}")
print_(f"The shape of valid_data label = {np.shape(valid_data[1])}")
 | |
| 
 | |
# Keeps the weights of the epoch that is best on the model-selection metric.
_best_model_checkpoint = ModelCheckpoint(
    filepath=path.join(PROJECT_DIR, "models", JOB_NAME + ".h5"),
    monitor=MODEL_SELECTION_METRIC,
    mode=MODEL_SELECTION_DIRECTION,
    save_best_only=True,
    verbose=2)

# Second checkpoint: best epoch according to the validation dice coefficient.
_best_dice_checkpoint = ModelCheckpoint(
    filepath=path.join(PROJECT_DIR, "models", JOB_NAME + "_dice.h5"),
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True,
    verbose=2)

# Per-epoch metrics are written to a csv file in the logs folder.
_csv_logger = CSVLogger(
    filename=path.join(PROJECT_DIR, "logs", f"{JOB_NAME}.csv"))

# Stop once the monitored metric has not improved for EARLY_STOPPING epochs.
_early_stopping = EarlyStopping(
    monitor=EARLY_STOPPING_METRIC,
    mode=EARLY_STOPPING_DIRECTION,
    patience=EARLY_STOPPING,
    verbose=2)

# Project callback fed with the validation set and an output file prefix.
# NOTE(review): presumably writes intermediate prediction images for up to 25
# validation cases - confirm in sfransen.DWI_exp.
_intermediate_images = IntermediateImages(
    validation_set=valid_data,
    sequences=args.series,
    prefix=path.join(PROJECT_DIR, "output", JOB_NAME),
    num_images=25)

callbacks = [
    _best_model_checkpoint,
    _best_dice_checkpoint,
    _csv_logger,
    _early_stopping,
    _intermediate_images,
]
 | |
| 
 | |
# Run the training loop: one epoch consists of as many full batches as fit in
# the training set, evaluated against the fixed validation set after each
# epoch, for at most MAX_EPOCHS epochs (the callbacks may stop earlier).
_fit_options = dict(
    validation_data=valid_data,
    steps_per_epoch=len(train_idxs) // BATCH_SIZE,
    epochs=MAX_EPOCHS,
    callbacks=callbacks,
    verbose=2,
)
detection_model.fit(train_generator, **_fit_options)