# fast-mri/scripts/8.Visualize_training.py
#
# (file-viewer metadata: 60 lines, 1.5 KiB, Python, executable file)

import matplotlib.pyplot as plt
import pandas as pd
import glob
import argparse
import os
def parse_input_args(argv=None):
    """Parse the command-line arguments of this plotting script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for testing). ``None`` keeps the
            default argparse behaviour of reading the real command line.

    Returns:
        argparse.Namespace with one attribute, ``train_out_dir``: the path
        of the training-statistics CSV file that will be plotted.
    """
    parser = argparse.ArgumentParser(
        description='Plot the training statistics of a Reconstruction model experiment')
    # The value is read directly with pd.read_csv, so it must be the path of
    # the CSV file itself, not the directory that contains it.
    parser.add_argument(
        'train_out_dir',
        type=str,
        help='Path to the .csv file with the train statistics of the desired '
             "experiment (an 'epoch' column plus one column per metric).",
    )
    return parser.parse_args(argv)
def resolve_figure_path(csv_path):
    """Return ``(experiment, png_path)`` for the training CSV at *csv_path*.

    The experiment name is the directory two levels above the CSV file, e.g.
    ``train_output/<experiment>/<subdir>/stats.csv``. When the path is too
    shallow to contain such a directory, the CSV's file stem is used instead,
    so the figure is never named just ``.png``. The figure path is
    ``<experiment>.png`` in the same directory as the CSV.
    """
    folder, csv_file = os.path.split(csv_path)
    # Grandparent directory of the CSV file.
    experiment = os.path.basename(os.path.dirname(folder))
    if not experiment:
        experiment = os.path.splitext(csv_file)[0]
    # os.path.join instead of f"{folder}/..." so a CSV in the current
    # directory yields "name.png" rather than "/name.png" at the fs root.
    return experiment, os.path.join(folder, f"{experiment}.png")


def main():
    """Plot every non-epoch column of the training CSV against 'epoch'.

    The CSV path comes from the command line (see parse_input_args). The
    figure is titled with the experiment name and saved as a PNG next to the
    CSV (see resolve_figure_path).

    Raises:
        ValueError: if the CSV has no 'epoch' column.
    """
    args = parse_input_args()
    print(f"Plotting {args}")
    csv_path = args.train_out_dir
    df = pd.read_csv(csv_path)
    if 'epoch' not in df.columns:
        raise ValueError(
            f"{csv_path} has no 'epoch' column; found columns {list(df.columns)}")

    # One line per metric column, all sharing the epoch x-axis.
    for metric in df.columns:
        if metric == 'epoch':
            continue
        plt.plot(df['epoch'], df[metric], label=metric)

    experiment, png_path = resolve_figure_path(csv_path)
    plt.title(experiment)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid()
    plt.legend()
    plt.savefig(png_path)
    plt.clf()
    plt.close()
    # Report the file that was actually written.
    print(f"\nsaved figure to {png_path}")


if __name__ == "__main__":
    main()
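# TODO: load 'loss' from the YAML config
# TODO: use the dict-to-YAML helper from utils
# (tip: a CSV viewer editor extension helps inspect the stats file)
# source .venv/bin/activate
# TODO: loop over all columns and, for metric columns, plot 1 - metric
# TODO: color coding via color=[]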