"""Plot fROC and ROC curves for one or more trained experiments.

For every name passed via ``--experiment`` the metrics file
``./../train_output/<name>/froc_metrics_focal_10.yml`` is read with
``read_yaml_to_dict`` and its curves are drawn on two shared figures:

* fROC: ``sensitivity`` against ``FP_per_case`` on a log-scaled x axis,
  saved as ``./../train_output/fROC_<saveas>.png``;
* ROC: ``tpr`` against ``fpr``, with the ``auroc`` value (rounded to 3
  decimals) appended to each legend entry, saved as
  ``./../train_output/ROC_<saveas>.png``.

With ``-comparison`` set, consecutive experiments are drawn as pairs sharing
one colour (solid line first, dashed line second).
"""
from pickle import TRUE  # NOTE(review): unused; looks like an accidental IDE auto-import.
from sfransen.utils_quintin import *
import matplotlib.pyplot as plt
import argparse
import matplotlib.ticker as tkr

parser = argparse.ArgumentParser(
    description='Visualise froc results')
parser.add_argument('-saveas', help='')
parser.add_argument('-comparison', help='')
parser.add_argument('--experiment', '-s', metavar='[series_name]',
                    required=True, nargs='+',
                    help='List of series to include, must correspond with '
                         'directory names in ./../train_output/')
args = parser.parse_args()

if args.comparison:
    # Pairs of experiments share a colour: solid for the first, dashed for the second.
    colors = ['r', 'r', 'b', 'b', 'g', 'g', 'y', 'y']
    plot_type = ['-', '--', '-', '--', '-', '--', '-', '--']
else:
    colors = ['r', 'b', 'g', 'k', 'y', 'c']
    plot_type = ['-', '-', '-', '-', '-', '-']

experiments = args.experiment
print(experiments)

# Create both figures ONCE so that every experiment's curve ends up on the
# same axes (previously a new fROC figure was created on every iteration).
fig1, ax1 = plt.subplots(1, 1)  # fROC
fig2, ax2 = plt.subplots(1, 1)  # ROC

auroc = []
for idx, experiment in enumerate(experiments):
    experiment_path = f'./../train_output/{experiment}/froc_metrics_focal_10.yml'
    experiment_metrics = read_yaml_to_dict(experiment_path)
    auroc.append(round(experiment_metrics['auroc'], 3))

    # Cycle through the style lists so that passing more experiments than
    # there are styles repeats styles instead of raising IndexError.
    color = colors[idx % len(colors)]
    linestyle = plot_type[idx % len(plot_type)]

    ax1.plot(experiment_metrics["FP_per_case"], experiment_metrics["sensitivity"],
             color=color, linestyle=linestyle)
    ax2.plot(experiment_metrics["fpr"], experiment_metrics["tpr"],
             color=color, linestyle=linestyle)
print(auroc)

# Shorten the experiment names for the legends.
experiments = [
    exp.replace('train_10h_', '').replace('train_n0.001_', '').replace('_', ' ')
    for exp in experiments
]

# --- fROC figure ---------------------------------------------------------
# set_xscale resets the tick locators/formatters, so it must come first.
ax1.set_xscale('log')
ax1.xaxis.set_minor_locator(tkr.LogLocator(base=10, subs='all'))
ax1.xaxis.set_minor_formatter(tkr.NullFormatter())
ax1.xaxis.set_major_formatter(tkr.ScalarFormatter())
# A log axis cannot show 0, so only positive tick positions are listed and
# only the upper x limit is fixed (a left limit of 0 would be ignored).
ax1.xaxis.set_major_locator(tkr.FixedLocator([0.1, 1, 3]))
ax1.set_xlim(right=3)
# Explicit grid; a bare grid() call would toggle it off again.
ax1.grid(True, which='both', ls='--', c='#d3d3d3')
ax1.set_title('fROC curve')
ax1.set_xlabel('False positive per case')
ax1.set_ylabel('Sensitivity')
ax1.legend(experiments, loc='lower right')
ax1.set_ylim((0, 1))
ax1.set_yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
fig1.savefig(f"./../train_output/fROC_{args.saveas}.png", dpi=300)

# --- ROC figure ----------------------------------------------------------
experiments_auroc = [f"{name} ({score})" for name, score in zip(experiments, auroc)]
ax2.set_title('ROC curve')
ax2.legend(experiments_auroc, loc='lower right')
ax2.set_xlabel('False positive rate')
ax2.set_ylabel('True positive rate')
ax2.grid(True)
fig2.savefig(f"./../train_output/ROC_{args.saveas}.png", dpi=300)