# Copyright 2022 Diagnostic Image Analysis Group, Radboudumc, Nijmegen, The Netherlands
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import os
import numpy as np
from tqdm import tqdm
from scipy import ndimage
from sklearn.metrics import roc_curve, auc
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import itertools

from typing import List, Tuple, Dict, Any, Union, Optional, Callable, Iterable, Hashable, Sized
try:
    import numpy.typing as npt
except ImportError:
    pass

from image_utils import (
    resize_image_with_crop_or_pad, read_label, read_prediction
)
from analysis_utils import (
    parse_detection_map, calculate_iou, calculate_dsc
)


# Compute base prediction metrics TP/FP/FN with associated model confidences
def evaluate_case(
    detection_map: "Union[npt.NDArray[np.float32], str]",
    label: "Union[npt.NDArray[np.int32], str]",
    min_overlap: float = 0.10,
    overlap_func: "Union[str, Callable[[npt.NDArray[np.float32], npt.NDArray[np.int32]], float]]" = 'IoU',
    multiple_lesion_candidates_selection_criteria: str = 'overlap',
    allow_unmatched_candidates_with_minimal_overlap: bool = True
) -> Tuple[List[Tuple[int, float, float]], int]:
    """
    Gather the list of lesion candidates, and classify them as TP/FP/FN.

    - multiple_lesion_candidates_selection_criteria: when multiple lesion candidates overlap with
      the same ground truth mask, use 'overlap' or 'confidence' to choose which lesion candidate
      is matched against the ground truth mask.

    Returns:
    - a list of tuples with:
      (is_lesion, prediction confidence, overlap score)
    - number of ground truth lesions
    """
    y_list: List[Tuple[int, float, float]] = []
    if isinstance(label, str):
        label = read_label(label)
    if isinstance(detection_map, str):
        detection_map = read_prediction(detection_map)
    if overlap_func == 'IoU':
        overlap_func = calculate_iou
    elif overlap_func == 'DSC':
        overlap_func = calculate_dsc
    else:
        raise ValueError(f"Overlap function with name {overlap_func} not recognized. Supported are 'IoU' and 'DSC'")

    # convert dtypes: int32 for the label, float32 for the detection map
    label = label.astype('int32')
    detection_map = detection_map.astype('float32')

    if detection_map.shape[0] < label.shape[0]:
        print("Warning: padding prediction to match label!")
        detection_map = resize_image_with_crop_or_pad(detection_map, label.shape)

    confidences, indexed_pred = parse_detection_map(detection_map)

    lesion_candidates_best_overlap: Dict[int, float] = {}

    if label.any():
        # malignant case: the ground truth contains at least one lesion
        labeled_gt, num_gt_lesions = ndimage.label(label, np.ones((3, 3, 3)))

        for lesion_id in range(1, num_gt_lesions + 1):
            # for each lesion in ground-truth (GT) label
            gt_lesion_mask = (labeled_gt == lesion_id)

            # collect indices of lesion candidates that have any overlap with the current GT lesion
            overlapping_lesion_candidate_indices = set(np.unique(indexed_pred[gt_lesion_mask]))
            overlapping_lesion_candidate_indices -= {0}  # remove index 0, if present

            # collect lesion candidates for current GT lesion
            lesion_candidates_for_target_gt: List[Dict[str, Union[int, float]]] = []
            for lesion_candidate_id, lesion_confidence in confidences:
                if lesion_candidate_id in overlapping_lesion_candidate_indices:
                    # calculate overlap between lesion candidate and GT mask
                    lesion_pred_mask = (indexed_pred == lesion_candidate_id)
                    overlap_score = overlap_func(lesion_pred_mask, gt_lesion_mask)

                    # keep track of the highest overlap a lesion candidate has with any GT lesion
                    lesion_candidates_best_overlap[lesion_candidate_id] = max(
                        overlap_score, lesion_candidates_best_overlap.get(lesion_candidate_id, 0)
                    )

                    # store lesion candidate info for current GT mask
                    lesion_candidates_for_target_gt.append({
                        'id': lesion_candidate_id,
                        'confidence': lesion_confidence,
                        'overlap': overlap_score,
                    })

            if len(lesion_candidates_for_target_gt) == 0:
                # no lesion candidate matched with GT mask. Add FN.
                y_list.append((1, 0., 0.))
            elif len(lesion_candidates_for_target_gt) == 1:
                # single lesion candidate overlapped with GT mask. Add TP if overlap is sufficient, or FN otherwise.
                candidate_info = lesion_candidates_for_target_gt[0]
                lesion_pred_mask = (indexed_pred == candidate_info['id'])

                if candidate_info['overlap'] > min_overlap:
                    # overlap between lesion candidate and GT mask is sufficient, add TP
                    indexed_pred[lesion_pred_mask] = 0  # remove lesion candidate after assignment
                    y_list.append((1, candidate_info['confidence'], candidate_info['overlap']))
                else:
                    # overlap between lesion candidate and GT mask is insufficient, add FN
                    y_list.append((1, 0., 0.))
            else:
                # multiple lesion candidates for current GT lesion
                # sort lesion candidates based on overlap or confidence
                key = multiple_lesion_candidates_selection_criteria
                lesion_candidates_for_target_gt = sorted(lesion_candidates_for_target_gt, key=lambda x: x[key], reverse=True)

                gt_lesion_matched = False
                for candidate_info in lesion_candidates_for_target_gt:
                    lesion_pred_mask = (indexed_pred == candidate_info['id'])

                    if candidate_info['overlap'] > min_overlap:
                        indexed_pred[lesion_pred_mask] = 0  # remove lesion candidate after assignment
                        y_list.append((1, candidate_info['confidence'], candidate_info['overlap']))
                        gt_lesion_matched = True
                        break

                if not gt_lesion_matched:
                    # ground truth lesion not matched to a lesion candidate. Add FN.
                    y_list.append((1, 0., 0.))

        # Remaining lesion candidates are FPs
        remaining_lesions = set(np.unique(indexed_pred))
        remaining_lesions -= {0}  # remove index 0, if present
        for lesion_candidate_id, lesion_confidence in confidences:
            if lesion_candidate_id in remaining_lesions:
                overlap_score = lesion_candidates_best_overlap.get(lesion_candidate_id, 0)
                if allow_unmatched_candidates_with_minimal_overlap and overlap_score > min_overlap:
                    # The lesion candidate was not matched to a GT lesion, but did have overlap > min_overlap
                    # with a GT lesion. The GT lesion is, however, matched to another lesion candidate.
                    # In this operation mode, this lesion candidate is not considered as a false positive.
                    pass
                else:
                    y_list.append((0, lesion_confidence, 0.))  # add FP

    else:
        # benign case: all lesion candidates are FPs
        num_gt_lesions = 0
        if len(confidences) > 0:
            for _, lesion_confidence in confidences:
                y_list.append((0, lesion_confidence, 0.))
        else:
            y_list.append((0, 0., 0.))  # avoid empty list

    return y_list, num_gt_lesions
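
# Illustrative sketch of calling `evaluate_case` directly (the file paths below are
# hypothetical; within this module the function is normally invoked through `froc`):
#
#   y_list, num_gt_lesions = evaluate_case(
#       detection_map="case_0001_detection_map.nii.gz",
#       label="case_0001_label.nii.gz",
#       min_overlap=0.10,
#       overlap_func='IoU',
#   )
#   # y_list holds one tuple per matched GT lesion / lesion candidate, e.g.
#   # [(1, 0.87, 0.45), (0, 0.12, 0.0)]  ->  (is_lesion, confidence, overlap),
#   # and num_gt_lesions is the number of ground truth lesions in the case.
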

# Calculate macro metrics (true positives (TP), false positives (FP))
def counts_from_lesion_evaluations(
    y_list: List[Tuple[int, float, float]],
    thresholds: "Optional[npt.NDArray[np.float64]]" = None
) -> "Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float64], int]":
    """
    Calculate true positives (TP) and false positives (FP) as a function of threshold,
    based on the case evaluations from `evaluate_case`.
    """
    # sort predictions
    y_list.sort()

    # collect targets and predictions
    y_true: "npt.NDArray[np.float64]" = np.array([target for target, *_ in y_list])
    y_pred: "npt.NDArray[np.float64]" = np.array([pred for _, pred, *_ in y_list])

    # calculate number of lesions
    num_lesions = y_true.sum()

    if thresholds is None:
        # collect thresholds for FROC analysis
        thresholds = np.unique(y_pred)
        thresholds[::-1].sort()  # sort thresholds in descending order (in-place)

        # for >10,000 thresholds: resample to 10,000 unique thresholds, while also
        # keeping all thresholds higher than 0.8 and the lowest 20 thresholds
        if len(thresholds) > 10_000:
            rng = np.arange(1, len(thresholds), len(thresholds) / 10_000, dtype=np.int32)
            st = [thresholds[i] for i in rng]
            low_thresholds = thresholds[-20:]
            thresholds = np.array([t for t in thresholds if t > 0.8 or t in st or t in low_thresholds])

    # define placeholders
    FP: "npt.NDArray[np.float32]" = np.zeros_like(thresholds, dtype=np.float32)
    TP: "npt.NDArray[np.float32]" = np.zeros_like(thresholds, dtype=np.float32)

    # for each threshold: count FPs and calculate the sensitivity
    for i, th in enumerate(thresholds):
        if th > 0:
            y_pred_thresholded = (y_pred >= th).astype(int)
            tp = np.sum(y_true * y_pred_thresholded)
            fp = np.sum(y_pred_thresholded - y_true * y_pred_thresholded)

            # update FROC with new point
            FP[i] = fp
            TP[i] = tp
        else:
            # extend FROC curve to infinity
            TP[i] = TP[-2]
            FP[i] = np.inf

    return TP, FP, thresholds, num_lesions
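
# Worked example (illustrative, made-up numbers) of how `counts_from_lesion_evaluations`
# turns case evaluations into FROC counts:
#
#   y_list = [(1, 0.9, 0.4), (1, 0.0, 0.0), (0, 0.3, 0.0)]
#   TP, FP, thresholds, num_lesions = counts_from_lesion_evaluations(y_list)
#   # unique confidences, descending: thresholds = [0.9, 0.3, 0.0]
#   # th=0.9 -> TP=1, FP=0; th=0.3 -> TP=1, FP=1; th=0.0 -> TP extended, FP=inf
#   # so TP = [1., 1., 1.], FP = [0., 1., inf], num_lesions = 2
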

# Calculate FROC metrics (FP rate, sensitivity)
def froc_from_lesion_evaluations(y_list, num_patients, thresholds=None):
    # calculate counts
    TP, FP, thresholds, num_lesions = counts_from_lesion_evaluations(
        y_list=y_list, thresholds=thresholds
    )

    # calculate FROC metrics from counts
    sensitivity = TP / num_lesions if num_lesions > 0 else np.nan
    FP_per_case = FP / num_patients

    return sensitivity, FP_per_case, thresholds, num_lesions
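
# Continuing the sketch above (illustrative numbers): with TP = [1., 1., 1.],
# FP = [0., 1., inf], num_lesions = 2 and num_patients = 2, this gives
# sensitivity = [0.5, 0.5, 0.5] and FP_per_case = [0., 0.5, inf].
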

def ap_from_lesion_evaluations(y_list, thresholds=None):
    # calculate counts
    TP, FP, thresholds, num_lesions = counts_from_lesion_evaluations(
        y_list=y_list, thresholds=thresholds
    )

    # calculate precision (lesion-level)
    precision = TP / (TP + FP)
    precision = np.append(precision, [0])

    # calculate recall (lesion-level)
    FN = num_lesions - TP
    recall = TP / (TP + FN)
    recall = np.append(recall, recall[-1:])

    # calculate average precision (lesion-level)
    AP = np.trapz(y=precision, x=recall)

    return AP, precision, recall, thresholds
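
# The average precision is the area under the lesion-level precision-recall curve,
# computed with the trapezoidal rule. As an illustrative example (made-up values):
# precision = [1.0, 0.5, 0.0] with recall = [0.5, 1.0, 1.0] gives
# np.trapz(y=[1.0, 0.5, 0.0], x=[0.5, 1.0, 1.0]) = 0.375.
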

# Compute full FROC
def froc(
    y_det: "Iterable[Union[npt.NDArray[np.float64], str, Path]]",
    y_true: "Iterable[Union[npt.NDArray[np.float64], str, Path]]",
    subject_list: Optional[Iterable[Hashable]] = None,
    min_overlap=0.10,
    overlap_func: "Union[str, Callable[[npt.NDArray[np.float32], npt.NDArray[np.int32]], float]]" = 'IoU',
    case_confidence: str = 'max',
    multiple_lesion_candidates_selection_criteria='overlap',
    allow_unmatched_candidates_with_minimal_overlap=True,
    flat: Optional[bool] = None,
    num_parallel_calls: int = 8,
    verbose: int = 0,
) -> Dict[str, Any]:
    """
    FROC evaluation pipeline
    (written 19 January 2022 by Joeran Bosma)

    Usage:
    For normal usage of the FROC evaluation pipeline, use the function `froc` with parameters
    `y_det`, `y_true` and (optional) `subject_list`. Please note that this function is written
    for binary 3D FROC analysis.

    - `y_det`: iterable of all detection_map volumes to evaluate. Alternatively, `y_det` may contain
      filenames ending in .nii.gz/.mha/.mhd/.npy/.npz, which will be loaded on-the-fly.
      Provide an array of shape `(num_samples, D, H, W)`, where D, H, W are the spatial
      dimensions (depth, height and width).
    - `y_true`: iterable of all ground truth labels. Alternatively, `y_true` may contain filenames
      ending in .nii.gz/.mha/.mhd/.npy/.npz, which should contain binary labels and
      will be loaded on-the-fly. Provide an array of the same shape as `y_det`. Use
      `1` to encode a ground truth lesion, and `0` to encode background.

    Additional settings:
    For more control over the FROC evaluation pipeline, use:
    - `min_overlap`: defines the minimal required Intersection over Union (IoU) or Dice similarity
      coefficient (DSC) between a lesion candidate and a ground truth lesion, to be
      counted as a true positive detection.
    """
    # Initialize Lists
    roc_true = {}
    roc_pred = {}
    y_list = []
    num_lesions = 0

    if subject_list is None:
        # generate indices to keep track of each case during multiprocessing
        subject_list = itertools.count()
        if flat is None:
            flat = True

    with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
        # define the jobs that need to be processed: evaluate_case, with each individual
        # detection_map prediction, ground truth label and parameters
        future_to_args = {
            pool.submit(evaluate_case, y_pred, y_true_case, min_overlap=min_overlap, overlap_func=overlap_func,
                        multiple_lesion_candidates_selection_criteria=multiple_lesion_candidates_selection_criteria,
                        allow_unmatched_candidates_with_minimal_overlap=allow_unmatched_candidates_with_minimal_overlap): idx
            for (y_pred, y_true_case, idx) in zip(y_det, y_true, subject_list)
        }

        # process the cases in parallel
        iterator = concurrent.futures.as_completed(future_to_args)
        if verbose:
            total: Optional[int] = None
            if isinstance(subject_list, Sized):
                total = len(subject_list)
            iterator = tqdm(iterator, desc='Computing FROC', total=total)
        for future in iterator:
            try:
                res = future.result()
            except Exception as e:
                print(f"Exception: {e}")
            else:
                # unpack results
                y_list_pat, num_lesions_gt = res
                # note: y_list_pat contains tuples of (is_lesion, confidence, overlap)

                # aggregate results
                idx = future_to_args[future]
                roc_true[idx] = np.max([a[0] for a in y_list_pat])
                if case_confidence == 'max':
                    # take highest lesion confidence as case-level confidence
                    roc_pred[idx] = np.max([a[1] for a in y_list_pat])
                elif case_confidence == 'bayesian':
                    # if a_i is the probability the i-th lesion is csPCa, then the case-level
                    # probability to have any csPCa lesion is 1 - Π_i{ 1 - a_i}
                    roc_pred[idx] = 1 - np.prod([(1 - a[1]) for a in y_list_pat])
                else:
                    raise ValueError(f"Patient confidence calculation method not recognised. Got: {case_confidence}.")

                # accumulate outputs
                y_list += y_list_pat
                num_lesions += num_lesions_gt

    # calculate statistics
    num_patients = len(roc_true)

    # get lesion-level results
    sensitivity, FP_per_case, thresholds, num_lesions = froc_from_lesion_evaluations(
        y_list=y_list, num_patients=num_patients
    )

    # calculate recall, precision and average precision
    AP, precision, recall, _ = ap_from_lesion_evaluations(y_list, thresholds=thresholds)

    # calculate case-level AUROC
    # note: iterate over the processed case identifiers, which also works when
    # `subject_list` was a generator (e.g. itertools.count()) that is already consumed
    case_ids = sorted(roc_true)
    fpr, tpr, _ = roc_curve(y_true=[roc_true[s] for s in case_ids],
                            y_score=[roc_pred[s] for s in case_ids],
                            pos_label=1)
    auc_score = auc(fpr, tpr)

    if flat:
        # flatten roc_true and roc_pred
        roc_true_flat = [roc_true[s] for s in case_ids]
        roc_pred_flat = [roc_pred[s] for s in case_ids]

    metrics = {
        "FP_per_case": FP_per_case,
        "sensitivity": sensitivity,
        "thresholds": thresholds,
        "num_lesions": num_lesions,
        "num_patients": num_patients,
        "roc_true": (roc_true_flat if flat else roc_true),
        "roc_pred": (roc_pred_flat if flat else roc_pred),
        "AP": AP,
        "precision": precision,
        "recall": recall,

        # patient-level predictions
        'auroc': auc_score,
        'tpr': tpr,
        'fpr': fpr,
    }

    return metrics
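
# Illustrative sketch of the typical entry point (the filenames are hypothetical):
#
#   metrics = froc(
#       y_det=["case_0001_detection_map.nii.gz", "case_0002_detection_map.nii.gz"],
#       y_true=["case_0001_label.nii.gz", "case_0002_label.nii.gz"],
#       subject_list=["case_0001", "case_0002"],
#       min_overlap=0.10,
#       overlap_func='IoU',
#   )
#   print(metrics["AP"], metrics["auroc"], metrics["num_lesions"])
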

def froc_for_folder(
    y_det_dir: Union[Path, str],
    y_true_dir: Optional[Union[Path, str]] = None,
    subject_list: Optional[List[str]] = None,
    min_overlap: float = 0.10,
    overlap_func: "Union[str, Callable[[npt.NDArray[np.float32], npt.NDArray[np.int32]], float]]" = 'IoU',
    case_confidence: str = 'max',
    flat: Optional[bool] = None,
    num_parallel_calls: int = 8,
    verbose: int = 1
) -> Dict[str, Any]:
    """
    Perform FROC evaluation for all samples found in y_det_dir, or the samples specified in subject_list.

    Input:
    - y_det_dir: path to folder containing the detection maps
    - y_true_dir: (optional) allow labels to be stored in a different folder
    - min_overlap: minimum overlap threshold
    - overlap_func: intersection over union (IoU), Dice similarity coefficient (DSC), or custom function
    """
    if y_true_dir is None:
        y_true_dir = y_det_dir

    y_det = []
    y_true = []
    if subject_list:
        # collect the detection maps and labels for each case specified in the subject list
        for subject_id in subject_list:
            # construct paths to detection maps and labels
            for postfix in [
                "_detection_map.nii.gz", "_detection_map.npy", "_detection_map.npz",
                ".nii.gz", ".npy", ".npz", "_pred.nii.gz",
            ]:
                detection_path = os.path.join(y_det_dir, f"{subject_id}{postfix}")
                if os.path.exists(detection_path):
                    break

            for postfix in [
                "_label.nii.gz", "label.npy", "label.npz", "_seg.nii.gz",
            ]:
                label_path = os.path.join(y_true_dir, f"{subject_id}{postfix}")
                if os.path.exists(label_path):
                    break
            if not os.path.exists(label_path):
                assert y_true_dir != y_det_dir, f"Could not find label for {subject_id}!"
                for postfix in [
                    ".nii.gz", ".npy", ".npz",
                ]:
                    label_path = os.path.join(y_true_dir, f"{subject_id}{postfix}")
                    if os.path.exists(label_path):
                        break

            # collect file paths
            y_det += [detection_path]
            y_true += [label_path]
    else:
        # collect all detection maps found in y_det_dir
        file_list = sorted(os.listdir(y_det_dir))
        subject_list = []
        if verbose >= 1:
            print(f"Found {len(file_list)} files in the input directory, collecting detection maps with " +
                  "_detection_map.nii.gz and labels with _label.nii.gz..")

        # collect filenames of detection_map predictions and labels
        for fn in file_list:
            if '_detection_map' in fn:
                y_det += [os.path.join(y_det_dir, fn)]
                y_true += [os.path.join(y_true_dir, fn.replace('_detection_map', '_label'))]
                subject_list += [fn]

    # ensure files exist
    for detection_path in y_det:
        assert os.path.exists(detection_path), f"Could not find detection map at {detection_path}!"
    for label_path in y_true:
        assert os.path.exists(label_path), f"Could not find label at {label_path}!"

    if verbose >= 1:
        print(f"Found prediction and label for {len(y_det)} cases. Here are some examples:")
        print(subject_list[0:5])

    # perform FROC evaluation with compiled file lists
    return froc(y_det=y_det, y_true=y_true, subject_list=subject_list,
                min_overlap=min_overlap, overlap_func=overlap_func, case_confidence=case_confidence,
                flat=flat, num_parallel_calls=num_parallel_calls, verbose=verbose)
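

# Illustrative sketch of folder-based evaluation (the directory names are hypothetical):
# detection maps are matched as `<subject_id>_detection_map.nii.gz` and labels as
# `<subject_id>_label.nii.gz`, following the filename conventions above.
#
#   metrics = froc_for_folder(
#       y_det_dir="/path/to/detection_maps",
#       y_true_dir="/path/to/labels",
#   )
#   print(metrics["sensitivity"], metrics["FP_per_case"])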