import multiprocessing
from os import path
import argparse
import time
from datetime import datetime
import sys
# sys.path.append('./../code')
# from utils_quintin import *
from sfransen.utils_quintin import *
# sys.path.append('./../code/DWI_exp')
# from callbacks import IntermediateImages, dice_coef
# from callbacks import RocCallback
from sfransen.DWI_exp import IntermediateImages, dice_coef
from sfransen.DWI_exp.preprocessing_function import preprocess
from sfransen.DWI_exp.losses import weighted_binary_cross_entropy
import yaml
import numpy as np
import SimpleITK as sitk  # used below for sitk.ReadImage / sitk.sitkFloat32
from tqdm import tqdm
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import KFold

from sfransen.DWI_exp.helpers import *
from sfransen.DWI_exp.batchgenerator import BatchGenerator
from sfransen.DWI_exp.unet import build_dual_attention_unet
from focal_loss import BinaryFocalLoss

parser = argparse.ArgumentParser(
    description='Train a U-Net model for segmentation/detection tasks '
                'using cross-validation.')
parser.add_argument('--series', '-s',
                    metavar='[series_name]', required=True, nargs='+',
                    help='List of series to include; must correspond with '
                         'path files in ./data/')
parser.add_argument('-experiment',
                    help='Experiment title used to name the output files, '
                         'e.g. test_b50_b400_b800')
parser.add_argument('-fold',
                    help='Fold index; selects train_val_test_idxs_<fold>.yml')
args = parser.parse_args()

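# Example invocation (script name, series names and fold value are illustrative
# and depend on the path files available in ./data/):
#   python train.py --series b50 b400 b800 -experiment test_b50_b400_b800 -fold 0
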
# Determine the number of input series
num_series = len(args.series)

# Identify this job by the experiment title; the output folder will get this
# name, e.g.: test_b50_b400_b800
# JOB_NAME = '_'.join(args.series)
JOB_NAME = args.experiment
DATA_DIR = "./../data/Nijmegen paths/"
# DATA_DIR = "./../data/new/"
# PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{args.experiment}"
PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{JOB_NAME}"

# 0.5 x 0.5 mm in-plane resolution, 3 mm slice thickness
TARGET_SPACING = (0.5, 0.5, 3)
INPUT_SHAPE = (192, 192, 24, num_series)  # (64, 64, 20, num_series)
IMAGE_SHAPE = INPUT_SHAPE[:3]
OUTPUT_SHAPE = (192, 192, 24, 1)  # One output channel (segmentation)

# Hyperparameters
FOCAL_LOSS_GAMMA = 2
INITIAL_LEARNING_RATE = 1e-4
MAX_EPOCHS = 1500
EARLY_STOPPING = 50
# increase batch size
BATCH_SIZE = 12
MODEL_SELECTION_METRIC = 'val_loss'
MODEL_SELECTION_DIRECTION = "min" # Change to 'max' if higher value is better
EARLY_STOPPING_METRIC = 'val_loss'
EARLY_STOPPING_DIRECTION = 'min'
# MODEL_SELECTION_METRIC = 'weighted_binary_cross_entropy'
# MODEL_SELECTION_DIRECTION = "min" # Change to 'max' if higher value is better
# EARLY_STOPPING_METRIC = 'weighted_binary_cross_entropy'
# EARLY_STOPPING_DIRECTION = 'min'

# Training configuration
# TODO: add ROC AUC as a metric
TRAINING_METRICS = ["binary_crossentropy", "binary_accuracy", dice_coef]
# loss = BinaryFocalLoss(gamma=FOCAL_LOSS_GAMMA)
weight_for_0 = 0.05
weight_for_1 = 0.95
loss = weighted_binary_cross_entropy({0: weight_for_0, 1: weight_for_1})
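# The class weights counter the voxel-level class imbalance: background voxels
# (class 0) are weighted 0.05 and lesion voxels (class 1) 0.95. Assuming the
# helper in sfransen.DWI_exp.losses follows the usual weighted BCE definition,
# the per-voxel loss is roughly:
#   loss = -(0.95 * y * log(p) + 0.05 * (1 - y) * log(1 - p))
# where y is the label and p the predicted probability.
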
optimizer = Adam(learning_rate=INITIAL_LEARNING_RATE)

# Create folder structure in the output directory; if it already exists,
# prepare and use a "<name>_(2)" copy so the earlier run is not overwritten.
if path.exists(PROJECT_DIR):
    PROJECT_DIR = PROJECT_DIR + '_(2)'
    prepare_project_dir(PROJECT_DIR)
else:
    prepare_project_dir(PROJECT_DIR)

# Save the training parameters to a yaml file in the output directory
params = {
    "focal_loss_gamma": FOCAL_LOSS_GAMMA,
    "initial_learning_rate": INITIAL_LEARNING_RATE,
    "max_epochs": MAX_EPOCHS,
    "MODEL_SELECTION_METRIC": MODEL_SELECTION_METRIC,
    "EARLY_STOPPING_METRIC": EARLY_STOPPING_METRIC,
    "train_output_dir": PROJECT_DIR,
    "batch_size": BATCH_SIZE,
    "optimizer": optimizer,
    "loss": loss,
    "datetime": datetime.now().strftime("%Y-%m-%d")}
dump_dict_to_yaml(params, PROJECT_DIR, filename="params")

# Build the U-Net model
detection_model = build_dual_attention_unet(INPUT_SHAPE)
detection_model.summary(line_length=120)

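# With Keras' default channels-last layout the network is fed batches of shape
# (batch, 192, 192, 24, num_series); OUTPUT_SHAPE above suggests it returns a
# single-channel probability map of the same spatial size (an assumption based
# on the constants, not on the unet implementation itself).
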
# Load all numpy images into RAM
images, image_paths = {s: [] for s in args.series}, {}
segmentations = []
print_("> Loading images into RAM...")

# Read the image paths from the data directory.
# Text files are expected to be named "[series_name].txt"
for s in args.series:
    with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
        image_paths[s] = [l.strip() for l in f.readlines()]
with open(path.join(DATA_DIR, "seg.txt"), 'r') as f:
    seg_paths = [l.strip() for l in f.readlines()]
num_images = len(seg_paths)

# Read and preprocess each of the paths for each series, and the segmentations.
for img_idx in tqdm(range(num_images)):  # use range(num_images)[:40] for fewer images
    img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)
             for s in args.series}
    seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
    img_n, seg_n = preprocess(img_s, seg_s,
                              shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
    for seq in img_n:
        images[seq].append(img_n[seq])
    segmentations.append(seg_n)

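# After this loop, images[series] holds num_images preprocessed volumes of
# shape IMAGE_SHAPE for every requested series, and segmentations holds the
# matching label volumes, all indexed consistently by position.
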
# Split train and validation.
# The train/validation indices are read from a per-fold yaml file; the KFold
# split below is kept for reference but no longer used.
# kfold = KFold(10, shuffle=True, random_state=123)
# train_idxs, valid_idxs = list(kfold.split(segmentations))[0]
# train_idxs = list(train_idxs)
# valid_idxs = list(valid_idxs)

yml_paths = read_yaml_to_dict(f'./../data/Nijmegen paths/train_val_test_idxs_{args.fold}.yml')
print('train/validation indices:', yml_paths)
train_idxs = yml_paths['train_set0']
valid_idxs = yml_paths['val_set0']

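# The fold file is assumed to hold index lists into the arrays loaded above,
# roughly of the form below (only train_set0 and val_set0 are used here; any
# test split stored in the same file is ignored by this script):
#   train_set0: [0, 1, 2, ...]
#   val_set0: [3, 7, ...]
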
detection_model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=TRAINING_METRICS
)

train_generator = BatchGenerator(images, segmentations,
                                 sequences=args.series,
                                 shape=IMAGE_SHAPE,
                                 indexes=train_idxs,
                                 batch_size=BATCH_SIZE,
                                 shuffle=True,
                                 augmentation_function=augment
                                 )

valid_generator = get_generator(images, segmentations,
                                sequences=args.series,
                                shape=IMAGE_SHAPE,
                                indexes=valid_idxs,
                                batch_size=None,
                                shuffle=False,
                                augmentation=None
                                )
valid_data = next(valid_generator)
print_(f"The shape of valid_data input = {np.shape(valid_data[0])}")
print_(f"The shape of valid_data label = {np.shape(valid_data[1])}")
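# With batch_size=None and shuffle=False the generator is assumed to yield the
# entire validation set at once, so valid_data is a single (inputs, labels)
# pair that is reused for validation, checkpoint selection and the
# IntermediateImages callback below.
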
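# Callbacks: keep the best model by val_loss and, separately, the best model by
# val_dice_coef, log every epoch to csv, stop early after EARLY_STOPPING epochs
# without improvement, and write intermediate predictions for a few validation
# cases to the output folder.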
callbacks = [
    ModelCheckpoint(
        filepath=path.join(PROJECT_DIR, "models", JOB_NAME + ".h5"),
        monitor=MODEL_SELECTION_METRIC,
        mode=MODEL_SELECTION_DIRECTION,
        verbose=2,
        save_best_only=True),
    ModelCheckpoint(
        filepath=path.join(PROJECT_DIR, "models", JOB_NAME + "_dice" + ".h5"),
        monitor='val_dice_coef',
        mode='max',
        verbose=2,
        save_best_only=True),
    CSVLogger(
        filename=path.join(PROJECT_DIR, "logs", f"{JOB_NAME}.csv")),
    EarlyStopping(
        monitor=EARLY_STOPPING_METRIC,
        mode=EARLY_STOPPING_DIRECTION,
        patience=EARLY_STOPPING,
        verbose=2),
    IntermediateImages(
        validation_set=valid_data,
        sequences=args.series,
        prefix=path.join(PROJECT_DIR, "output", JOB_NAME),
        num_images=25)
    # RocCallback(
    #     validation_set=valid_data,
    #     num_images=25)
]

detection_model.fit(train_generator,
                    validation_data=valid_data,
                    steps_per_epoch=len(train_idxs) // BATCH_SIZE,
                    epochs=MAX_EPOCHS,
                    callbacks=callbacks,
                    verbose=2)
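
# Training artefacts end up under PROJECT_DIR: the params yaml, the two model
# checkpoints in models/, the per-epoch metrics in logs/<experiment>.csv and
# the intermediate prediction images in output/.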