fast-mri/scripts/1.U-net_chris.py

219 lines
7.5 KiB
Python
Raw Normal View History

2022-03-21 10:14:00 +01:00
import multiprocessing
from os import path
import argparse
import time
from datetime import datetime
import sys
# sys.path.append('./../code')
# from utils_quintin import *
2022-03-23 17:00:22 +01:00
from sfransen.utils_quintin import *
2022-03-21 10:14:00 +01:00
# sys.path.append('./../code/DWI_exp')
# from callbacks import IntermediateImages, dice_coef
# from callbacks import RocCallback
2022-03-23 17:00:22 +01:00
from sfransen.utils_quintin import *
2022-03-21 10:14:00 +01:00
from sfransen.DWI_exp import IntermediateImages, dice_coef
from sfransen.DWI_exp.preprocessing_function import preprocess
2022-04-21 14:29:36 +02:00
from sfransen.DWI_exp.losses import weighted_binary_cross_entropy
2022-03-21 10:14:00 +01:00
import yaml
import numpy as np
from tqdm import tqdm
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import KFold
from sfransen.DWI_exp.helpers import *
from sfransen.DWI_exp.batchgenerator import BatchGenerator
from sfransen.DWI_exp.unet import build_dual_attention_unet
from focal_loss import BinaryFocalLoss
# Command-line interface. The arguments are parsed once, at import time, since
# this file is a standalone training script.
# Adjacent string literals are concatenated by Python, so each fragment must
# carry its own separating space.
parser = argparse.ArgumentParser(
    description='Train a U-Net model for segmentation/detection tasks '
                'using cross-validation.')
parser.add_argument('--series', '-s',
                    metavar='[series_name]', required=True, nargs='+',
                    help='List of series to include, must correspond with '
                         'path files in ./data/')
parser.add_argument('-experiment',
                    help='add experiment title to store the files correctly: '
                         'test_b50_b400_b800')
parser.add_argument('-fold',
                    help='import fold')
args = parser.parse_args()
# Determine the number of input series (one input channel per series).
num_series = len(args.series)
# Identify this job by its experiment title; the output folder gets this name.
# -experiment is optional, so when it is omitted fall back to the series names
# joined by underscores (e.g. b0_b50_b100) instead of producing a folder
# literally named "None".
JOB_NAME = args.experiment or '_'.join(args.series)
DATA_DIR = "./../data/Nijmegen paths/"
# DATA_DIR = "./../data/new/"
PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{JOB_NAME}"
# Target voxel spacing handed to preprocess(): 0.5 x 0.5 in-plane and 3 along
# the slice axis (presumably mm, as SimpleITK spacing usually is - confirm).
TARGET_SPACING = (0.5, 0.5, 3)
INPUT_SHAPE = (192, 192, 24, num_series)  # earlier runs used (64, 64, 20, num_series)
IMAGE_SHAPE = INPUT_SHAPE[:3]
# One output channel (segmentation) on the same spatial grid as the input.
OUTPUT_SHAPE = IMAGE_SHAPE + (1,)
# ---- Optimisation hyperparameters ----
INITIAL_LEARNING_RATE = 1e-4
MAX_EPOCHS = 1500
EARLY_STOPPING = 50  # patience (in epochs) of the EarlyStopping callback
BATCH_SIZE = 12
# Only recorded in params.yml; the focal loss that would use it is currently
# disabled in favour of a weighted binary cross-entropy (see below).
FOCAL_LOSS_GAMMA = 2

# ---- Model selection and early stopping ----
# Both monitor the validation loss, where lower is better. Switch a direction
# to 'max' for a metric where higher is better. An earlier configuration
# monitored 'weighted_binary_cross_entropy' instead of 'val_loss'.
MODEL_SELECTION_METRIC, MODEL_SELECTION_DIRECTION = 'val_loss', "min"
EARLY_STOPPING_METRIC, EARLY_STOPPING_DIRECTION = 'val_loss', 'min'

# ---- Metrics, loss and optimizer ----
# TODO: consider adding ROC AUC as a training metric.
TRAINING_METRICS = ["binary_crossentropy", "binary_accuracy", dice_coef]
# Per-class weights keyed by label value (presumably 0 = background,
# 1 = foreground - confirm against weighted_binary_cross_entropy).
# Alternative loss: BinaryFocalLoss(gamma=FOCAL_LOSS_GAMMA).
weight_for_0, weight_for_1 = 0.05, 0.95
loss = weighted_binary_cross_entropy({0: weight_for_0, 1: weight_for_1})
optimizer = Adam(learning_rate=INITIAL_LEARNING_RATE)
# Create the folder structure in the output directory.
# If a run with this name already exists, use the first free "_(k)" suffix
# (k = 2, 3, ...) and repoint PROJECT_DIR at it. Every artefact written later
# (params, checkpoints, CSV log, intermediate images) is built from
# PROJECT_DIR, so it must name the folder that was actually prepared.
# Otherwise the earlier run's files are overwritten.
if path.exists(PROJECT_DIR):
    _base_project_dir = PROJECT_DIR
    _suffix = 2
    while path.exists(f"{_base_project_dir}_({_suffix})"):
        _suffix += 1
    PROJECT_DIR = f"{_base_project_dir}_({_suffix})"
prepare_project_dir(PROJECT_DIR)

# Save the run parameters to yaml for reproducibility.
# The date is computed once, echoed to stdout, and stored. print() returns
# None, so the date must not be taken from the print call itself.
run_date = datetime.now().strftime("%Y-%m-%d")
print(run_date)
params = {
    "focal_loss_gamma": FOCAL_LOSS_GAMMA,
    "initial_learning_rate": INITIAL_LEARNING_RATE,
    "max_epochs": MAX_EPOCHS,
    "MODEL_SELECTION_METRIC": MODEL_SELECTION_METRIC,
    "EARLY_STOPPING_METRIC": EARLY_STOPPING_METRIC,
    "train_output_dir": PROJECT_DIR,
    "batch_size": BATCH_SIZE,
    # NOTE(review): optimizer and loss are stored as Python objects; how they
    # end up in the yaml depends on dump_dict_to_yaml - confirm.
    "optimizer": optimizer,
    "loss": loss,
    "datetime": run_date}
dump_dict_to_yaml(params, PROJECT_DIR, filename="params")
# Instantiate the dual-attention U-Net for the configured input shape
# (spatial dims + one channel per series) and print its layer-by-layer summary
# using 120-character lines.
detection_model = build_dual_attention_unet(INPUT_SHAPE)
detection_model.summary(line_length=120)
def _read_path_list(filename):
    """Return the stripped, non-empty lines of ``DATA_DIR/filename``.

    Each line is expected to hold one file path. Blank lines (for example a
    stray trailing empty line) are skipped so they never reach
    ``sitk.ReadImage``.
    """
    with open(path.join(DATA_DIR, filename), 'r') as f:
        return [line.strip() for line in f if line.strip()]


# Load all numpy images into RAM
images, image_paths = {s: [] for s in args.series}, {}
segmentations = []
print_("> Loading images into RAM...")

# Read the image paths from the data directory.
# Text files are expected to have the name "[series_name].txt". "seg.txt" lists
# the matching segmentations, and entries are paired by line position.
for s in args.series:
    image_paths[s] = _read_path_list(f"{s}.txt")
seg_paths = _read_path_list("seg.txt")
num_images = len(seg_paths)

# Every series list must line up one-to-one with the segmentation list.
# Otherwise images are paired with the wrong segmentation, or the loop below
# fails with an IndexError after part of the data has already been loaded.
for s in args.series:
    if len(image_paths[s]) != num_images:
        raise ValueError(
            f"{s}.txt lists {len(image_paths[s])} paths but seg.txt lists "
            f"{num_images}; the path files must be aligned line by line.")

# Read and preprocess each of the paths for each series, and the segmentations.
for img_idx in tqdm(range(num_images)):
    img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)
             for s in args.series}
    seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
    img_n, seg_n = preprocess(img_s, seg_s,
                              shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
    for seq in img_n:
        images[seq].append(img_n[seq])
    segmentations.append(seg_n)
# Split train and validation.
# The split is predefined per fold in "train_val_test_idxs_<fold>.yml" inside
# DATA_DIR. An earlier version derived a 1:9 split with
# KFold(10, shuffle=True, random_state=123) instead. The indexes are passed to
# the batch generators as `indexes=`, presumably positions in the
# images/segmentations lists loaded above - confirm.
if args.fold is None:
    # Fail fast with a usage message instead of a missing
    # "train_val_test_idxs_None.yml" error after all images were loaded.
    parser.error("the -fold argument is required to select the train/val split")
yml_paths = read_yaml_to_dict(
    path.join(DATA_DIR, f"train_val_test_idxs_{args.fold}.yml"))
print('test, train paths', yml_paths)
# NOTE(review): every fold file is read with the "*_set0" keys; confirm that
# each per-fold yml stores its split under set0.
train_idxs = yml_paths['train_set0']
valid_idxs = yml_paths['val_set0']
# Compile with the optimizer, loss and metrics configured above.
detection_model.compile(optimizer=optimizer, loss=loss, metrics=TRAINING_METRICS)

# Keyword arguments common to the training and validation generators.
_shared_generator_args = dict(sequences=args.series, shape=IMAGE_SHAPE)

# Training batches: BATCH_SIZE samples each, shuffled, with augmentation.
train_generator = BatchGenerator(
    images, segmentations,
    indexes=train_idxs,
    batch_size=BATCH_SIZE,
    shuffle=True,
    augmentation_function=augment,
    **_shared_generator_args,
)

# Validation data: unshuffled and un-augmented. With batch_size=None a single
# draw presumably yields the whole validation set (confirm in get_generator);
# that one (inputs, labels) pair is reused for every epoch.
valid_generator = get_generator(
    images, segmentations,
    indexes=valid_idxs,
    batch_size=None,
    shuffle=False,
    augmentation=None,
    **_shared_generator_args,
)
valid_data = next(valid_generator)
for _part_name, _part in (("input", valid_data[0]), ("label", valid_data[1])):
    print_(f"The shape of valid_data {_part_name} = {np.shape(_part)}")
# Training callbacks:
#  - best checkpoint by MODEL_SELECTION_METRIC -> models/<JOB_NAME>.h5
#  - best checkpoint by validation Dice        -> models/<JOB_NAME>_dice.h5
#  - per-epoch metric log                      -> logs/<JOB_NAME>.csv
#  - early stopping on EARLY_STOPPING_METRIC after EARLY_STOPPING epochs
#  - IntermediateImages on up to 25 validation cases, written under
#    output/<JOB_NAME> (presumably sample predictions - confirm)
callbacks = [
    ModelCheckpoint(
        filepath=path.join(PROJECT_DIR, "models", JOB_NAME + ".h5"),
        monitor=MODEL_SELECTION_METRIC,
        mode=MODEL_SELECTION_DIRECTION,
        verbose=2,
        save_best_only=True),
    ModelCheckpoint(
        filepath=path.join(PROJECT_DIR, "models", JOB_NAME + "_dice" + ".h5"),
        monitor='val_dice_coef',
        mode='max',
        verbose=2,
        save_best_only=True),
    CSVLogger(
        filename=path.join(PROJECT_DIR, "logs", f"{JOB_NAME}.csv")),
    EarlyStopping(
        monitor=EARLY_STOPPING_METRIC,
        mode=EARLY_STOPPING_DIRECTION,
        patience=EARLY_STOPPING,
        verbose=2),
    IntermediateImages(
        validation_set=valid_data,
        sequences=args.series,
        prefix=path.join(PROJECT_DIR, "output", JOB_NAME),
        num_images=25),
    # RocCallback(
    #     validation_set=valid_data,
    #     num_images=25)
]

# Full batches per epoch, clamped to at least 1. With fewer training cases than
# BATCH_SIZE the plain floor division gives 0 steps, and no training happens.
steps_per_epoch = max(1, len(train_idxs) // BATCH_SIZE)

detection_model.fit(train_generator,
                    validation_data=valid_data,
                    steps_per_epoch=steps_per_epoch,
                    epochs=MAX_EPOCHS,
                    callbacks=callbacks,
                    verbose=2)