60 lines
1.5 KiB
Python
60 lines
1.5 KiB
Python
|
import matplotlib.pyplot as plt
|
||
|
import pandas as pd
|
||
|
import glob
|
||
|
import argparse
|
||
|
import os
|
||
|
|
||
|
def parse_input_args(argv=None):
    """Parse the command-line arguments of the loss-plotting script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``); handy for testing. ``None``
            (the default) lets argparse read the actual command line.

    Returns:
        argparse.Namespace with one attribute, ``train_out_dir``: the path to
        the CSV file holding the training statistics.
    """
    parser = argparse.ArgumentParser(
        description='Plot training and validation loss from a training-statistics CSV file')

    # Despite its historical name, this argument is the path to the CSV file
    # itself, not a directory: the script reads it directly with pd.read_csv
    # and splits it into directory and file name.
    parser.add_argument('train_out_dir',
                        type=str,
                        help='Path to the .csv file with the train statistics '
                             'of the desired experiment.')

    args = parser.parse_args(argv)

    return args
|
||
|
|
||
|
args = parse_input_args()

print(f"Plotting {args}")

# The positional argument is the path to the training-statistics CSV file.
folder_input = args.train_out_dir

# Load the per-epoch training statistics into a DataFrame.
df = pd.read_csv(folder_input)
|
||
|
|
||
|
# Plot the training and validation loss curves against the 'epoch' column.
# Iterating the columns (rather than a fixed list) keeps the legend in the
# CSV's column order and silently skips curves that are absent.
for metric in df.columns:
    if metric in ('loss', 'val_loss'):
        plt.plot(df['epoch'], df[metric], label=metric)

# Fix the y-axis to [0, 0.01] once, instead of once per curve; losses above
# 0.01 fall outside the visible range. `bottom`/`top` are the canonical names
# for the legacy `ymin`/`ymax` aliases.
plt.ylim(bottom=0, top=0.01)
|
||
|
|
||
|
|
||
|
|
||
|
# Split the CSV path into its containing directory and file name.
folder, csvfile = os.path.split(args.train_out_dir)

# The experiment name is taken from the parent of the directory that holds
# the CSV file.
_, experiment = os.path.split(os.path.split(folder)[0])
if not experiment:
    # Shallow path (e.g. a bare file name): fall back to the CSV file's stem
    # so the figure is not titled "" and saved as a hidden ".png".
    experiment = os.path.splitext(csvfile)[0]

plt.title(experiment)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.legend()

# os.path.join keeps the figure next to the CSV even when `folder` is empty;
# the former f"{folder}/..." pointed at the filesystem root in that case.
out_path = os.path.join(folder, f"{experiment}.png")
plt.savefig(out_path)
plt.close()

# Report the path that was actually written.
print(f"\nsaved figure to {out_path}")
|
||
|
|
||
|
# TODO: load 'loss' from the yaml file
|
||
|
# TODO: from utils > dict to yaml
|
||
|
# csv viewer extension
|
||
|
# source .venv/bin/activate
|
||
|
# TODO: loop over all columns and, for metric columns, plot 1 - metric
|
||
|
# TODO: color coding, e.g. color=[]
|
||
|
|
||
|
|