fast-mri/scripts/3.make_train_val_test_index...

import argparse
import random
import sys
sys.path.append('./../code')
from utils_quintin import list_from_file, dump_dict_to_yaml, print_p
################################ README ######################################
# NEW - This script creates index sets for the training, validation and test
# splits. It writes a yaml file with one set of train, val and test indexes per
# fold (e.g. 10 sets for 10-fold cross validation when needed). The seed used
# to generate the indexes and the split percentages are stored in the file as
# well.
# Output structure:
# train_set0: [indexes]
# val_set0: [indexes]
# test_set0: [indexes]
# train_set1: [indexes]
# val_set1: [indexes]
# test_set1: [indexes]
# etc...
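#
# Example invocation (illustrative; adjust the paths to your own setup):
#   python <this script> --num_folds 10 \
#       --path_to_paths_file "./../data/Nijmegen paths/seg.txt" \
#       --split 80/10/10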
################################ PARSER ########################################
def parse_input_args():
    parser = argparse.ArgumentParser(description='Parse arguments for splitting the data into training, validation and test sets.')
    parser.add_argument('-n',
                        '--num_folds',
                        type=int,
                        default=1,
                        help='The number of folds in total. This amount of index sets will be created.')
    parser.add_argument('-p',
                        '--path_to_paths_file',
                        type=str,
                        default="./../data/Nijmegen paths/seg.txt",
                        help='Path to the .txt file containing paths to the nifti files.')
    parser.add_argument('-s',
                        '--split',
                        type=str,
                        default="80/10/10",
                        help='Train/validation/test split in percentages.')
    args = parser.parse_args()

    # Split the given "train/val/test" percentage string into its three parts.
    args.p_train = int(args.split.split('/')[0])
    args.p_val = int(args.split.split('/')[1])
    args.p_test = int(args.split.split('/')[2])
    assert args.p_train + args.p_val + args.p_test == 100, "The train, val, test split must sum to 100%."

    return args
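
# For the default split "80/10/10", parse_input_args() returns args with
# args.p_train = 80, args.p_val = 10 and args.p_test = 10.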
################################################################################
SEED = 3478

if __name__ == '__main__':

    print_p('\n\nMaking Train - Validation - Test indexes.')

    # Parse the input arguments.
    args = parse_input_args()
    print_p(args)

    # Read the number of observations/subjects in the data.
    t2_paths = list_from_file(args.path_to_paths_file)
    num_obs = len(t2_paths)
    print_p(f"Number of observations in {args.path_to_paths_file}: {num_obs}")

    # Create cutoff points for the training, validation and test set.
    train_cutoff = int(args.p_train/100 * num_obs)
    val_cutoff = int(args.p_val/100 * num_obs) + train_cutoff
    test_cutoff = int(args.p_test/100 * num_obs) + val_cutoff
    print(f"\ncutoffs: {train_cutoff}, {val_cutoff}, {test_cutoff}")

    # Create the dict that will hold all the data.
    data_dict = {}
    data_dict["init_seed"] = SEED
    data_dict["split"] = args.split

    # Loop over the number of folds; one train/val/test index set is created per fold.
    for set_idx in range(args.num_folds):

        # Set a new seed for this fold.
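        # Re-seeding with SEED + set_idx makes each fold's shuffle reproducible;
        # note that every fold is an independent reshuffle of all indexes, so the
        # test sets of different folds are not guaranteed to be disjoint.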
        random.seed(SEED + set_idx)

        # Shuffle the indexes.
        indexes = list(range(num_obs))
        random.shuffle(indexes)

        train_idxs = indexes[:train_cutoff]
        val_idxs = indexes[train_cutoff:val_cutoff]
        test_idxs = indexes[val_cutoff:test_cutoff]

        data_dict[f"train_set{set_idx}"] = train_idxs
        data_dict[f"val_set{set_idx}"] = val_idxs
        data_dict[f"test_set{set_idx}"] = test_idxs

    # Print a summary of the size of each index set.
    for key in data_dict:
        if isinstance(data_dict[key], list):
            print(f"{key}: {len(data_dict[key])}")

    dump_dict_to_yaml(data_dict, "./../data", filename="train_val_test_idxs", verbose=False)
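
# A downstream script could read the indexes back roughly as sketched below
# (assuming dump_dict_to_yaml writes "./../data/train_val_test_idxs.yaml"; the
# exact filename and extension depend on that helper):
#
#   import yaml
#   with open("./../data/train_val_test_idxs.yaml") as f:
#       idxs = yaml.safe_load(f)
#   train_paths = [t2_paths[i] for i in idxs["train_set0"]]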