import argparse
import random
import sys

sys.path.append('./../code')
from utils_quintin import list_from_file, dump_dict_to_yaml, print_p


################################ README ######################################
# NEW - This script creates the indexes for the training, validation and test
# sets. It writes a yaml file with 10 sets of indexes for train, val and test,
# so that they can be used for 10-fold cross validation when needed. The
# filename should contain an integer, which is the seed used to generate the
# indexes, and should also indicate the type of split used.
# Output structure:
# train_set0: [indexes]
# val_set0: [indexes]
# test_set0: [indexes]
# train_set1: [indexes]
# val_set1: [indexes]
# test_set1: [indexes]
# etc...
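#
# For illustration only, a hypothetical yaml produced by a run with
# --num_folds 2 on a tiny dataset of 10 observations (the index values below
# are made up; the real ones depend on the data and the seed):
#
#   init_seed: 3478
#   split: 80/10/10
#   train_set0: [3, 0, 7, 1, 6, 8, 4, 2]
#   val_set0: [9]
#   test_set0: [5]
#   train_set1: [0, 5, 2, 9, 4, 1, 8, 7]
#   val_set1: [3]
#   test_set1: [6]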


################################ PARSER ########################################


def parse_input_args():
    parser = argparse.ArgumentParser(description='Parse arguments for splitting the data into training, validation and test sets.')

    parser.add_argument('-n',
                        '--num_folds',
                        type=int,
                        default=1,
                        help='The total number of folds; this many index sets will be created.')

    parser.add_argument('-p',
                        '--path_to_paths_file',
                        type=str,
                        default="./../data/Nijmegen paths/seg.txt",
                        help='Path to the .txt file containing paths to the nifti files.')

    parser.add_argument('-s',
                        '--split',
                        type=str,
                        default="80/10/10",
                        help='Train/validation/test split in percentages.')

    args = parser.parse_args()

    # Split the given split string into its three percentage components.
    args.p_train = int(args.split.split('/')[0])
    args.p_val = int(args.split.split('/')[1])
    args.p_test = int(args.split.split('/')[2])

    assert args.p_train + args.p_val + args.p_test == 100, "The train, val and test split must sum to 100%."

    return args
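
# Example invocation (the script name make_train_val_test_idxs.py is an
# assumption here, since the real filename is not shown in this file):
#
#   python make_train_val_test_idxs.py -n 10 -s "80/10/10" \
#       -p "./../data/Nijmegen paths/seg.txt"
#
# which yields args.p_train == 80, args.p_val == 10 and args.p_test == 10.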


################################################################################
SEED = 3478

if __name__ == '__main__':
    print_p('\n\nMaking Train - Validation - Test indexes.')

    # Parse command line arguments.
    args = parse_input_args()
    print_p(args)

    # Read the number of observations/subjects in the data.
    t2_paths = list_from_file(args.path_to_paths_file)
    num_obs = len(t2_paths)
    print_p(f"Number of observations in {args.path_to_paths_file}: {num_obs}")

    # Create cutoff points for the training, validation and test sets.
    train_cutoff = int(args.p_train / 100 * num_obs)
    val_cutoff = int(args.p_val / 100 * num_obs) + train_cutoff
    test_cutoff = int(args.p_test / 100 * num_obs) + val_cutoff
    print(f"\ncutoffs: {train_cutoff}, {val_cutoff}, {test_cutoff}")
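
    # Worked example (illustrative, not from a real run): with num_obs = 97
    # and the default 80/10/10 split, int() truncation gives
    # train_cutoff = 77, val_cutoff = 77 + 9 = 86 and test_cutoff = 86 + 9 = 95,
    # so 2 of the 97 observations end up in no set when the percentages do not
    # divide num_obs evenly.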

    # Create the dict that will hold all the data.
    data_dict = {}
    data_dict["init_seed"] = SEED
    data_dict["split"] = args.split

    # Loop over the number of folds; one set of train/val/test indexes is
    # created per fold and stored in the yaml file.
    for set_idx in range(args.num_folds):

        # Set a new seed for each fold first, so every fold gets its own
        # reproducible shuffle.
        random.seed(SEED + set_idx)

        # Shuffle the indexes.
        indexes = list(range(num_obs))
        random.shuffle(indexes)

        train_idxs = indexes[:train_cutoff]
        val_idxs = indexes[train_cutoff:val_cutoff]
        test_idxs = indexes[val_cutoff:test_cutoff]

        data_dict[f"train_set{set_idx}"] = train_idxs
        data_dict[f"val_set{set_idx}"] = val_idxs
        data_dict[f"test_set{set_idx}"] = test_idxs
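
    # Note on the design: each fold is an independent shuffle of all indexes
    # rather than a rotation over a fixed partition, so the test sets of
    # different folds may overlap. This differs from classical k-fold cross
    # validation, where the k test folds partition the data exactly once.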

    # Report the size of each index set.
    for key in data_dict:
        if isinstance(data_dict[key], list):
            print(f"{key}: {len(data_dict[key])}")

    dump_dict_to_yaml(data_dict, "./../data", filename="train_val_test_idxs", verbose=False)
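
# A minimal sketch of how the indexes might be consumed later (assuming
# dump_dict_to_yaml writes ./../data/train_val_test_idxs.yaml and that PyYAML
# is available; both are assumptions, not shown in this script):
#
#   import yaml
#   with open('./../data/train_val_test_idxs.yaml') as f:
#       idxs = yaml.safe_load(f)
#   fold0_train_paths = [t2_paths[i] for i in idxs['train_set0']]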