NuMRI/kalman/graphics/figure3.py
2021-09-15 10:45:08 +02:00

257 lines
7.3 KiB
Python

import matplotlib.pyplot as plt
import numpy as np
from itertools import cycle
import argparse
import pickle
import yaml
#import matplotlib.font_manager
from matplotlib import rc
# Typeset all figure text with LaTeX, using a sans-serif (Helvetica) face.
rc('font', family='sans-serif', **{'sans-serif': ['Helvetica']})
rc('text', usetex=True)
def is_ipython():
    ''' Tell whether this code is running inside an IPython session.

    Only an IPython session makes the name ``get_ipython`` resolvable; in a
    plain interpreter referencing it raises ``NameError``.

    Returns:
        bool: True if IPython, else False
    '''
    try:
        get_ipython()  # noqa: F821 -- provided only by IPython
    except NameError:
        return False
    return True
def load_data(file):
    ''' Load numpy data from *file*.

    ``np.load`` returns a lazy ``NpzFile`` for ``.npz`` archives, which keeps
    the underlying file open until it is garbage-collected. Archives are
    therefore read eagerly into a plain dict and closed before returning.
    Any other result of ``np.load`` (e.g. the array stored in a ``.npy``
    file) is returned unchanged.

    Args:
        file (str or path-like): file to load.

    Returns:
        dict: data dictionary mapping array names to arrays (for ``.npz``
        archives).
    '''
    loaded = np.load(file)
    # Only NpzFile archives expose the stored array names through ``.files``.
    if hasattr(loaded, 'files'):
        with loaded:  # closes the archive once every array has been read
            return {key: loaded[key] for key in loaded.files}
    return loaded
def _case_name(input_file):
    ''' Return the text between the first and second '/' of *input_file*.

    For a statistics file stored as ``results/<case>/...`` this is
    ``<case>``. No path normalization is applied (``./results/x`` yields
    ``results``), and paths with fewer than two '/' separators do not give
    a meaningful name.
    '''
    idx_a = input_file.find('/')
    idx_b = input_file[idx_a + 1:].find('/')
    return input_file[idx_a + 1:idx_b + idx_a + 1]


def _collect_parameters(inputfile, rc_flag):
    ''' Collect the estimated boundary parameters from a parsed input.yaml.

    Args:
        inputfile (dict): parsed ``input.yaml``; reads
            ``inputfile['estimation']['boundary_conditions']`` and
            ``inputfile['boundary_conditions']``.
        rc_flag (bool): also collect the windkessel ``C`` values.

    Returns:
        tuple: ``(ids, ids_type, current_val, current_val_C, labels)`` --
        boundary ids, their kind (``'windkessel'`` or ``'dirichlet'``), the
        configured ``R_d``/``U`` values, the configured ``C`` values, and
        LaTeX legend prefixes in the order they are consumed when plotting
        (R_d then C for each windkessel, U for each Dirichlet boundary).
    '''
    bcs = inputfile['boundary_conditions']
    ids, ids_type, current_val, current_val_C, labels = [], [], [], [], []
    for bnd_c in inputfile['estimation']['boundary_conditions']:
        if 'windkessel' in bnd_c['type']:
            for bnd_set in bcs:
                if bnd_c['id'] != bnd_set['id']:
                    continue
                ids.append(bnd_c['id'])
                ids_type.append('windkessel')
                current_val.append(bnd_set['parameters']['R_d'])
                labels.append('$R_' + str(bnd_c['id'] - 3))
                if rc_flag:
                    current_val_C.append(bnd_set['parameters']['C'])
                    labels.append('$C_' + str(bnd_c['id'] - 3))
        elif 'dirichlet' in bnd_c['type']:
            # Look the boundary up by id, as the windkessel branch does; the
            # first entry is kept as a fallback when no id matches.
            bnd_set = next((b for b in bcs if b['id'] == bnd_c['id']), bcs[0])
            current_val.append(bnd_set['parameters']['U'])
            ids.append(bnd_c['id'])
            ids_type.append('dirichlet')
            labels.append('$U')
    return ids, ids_type, current_val, current_val_C, labels


def plot_parameters(dat, input_file, deparameterize=False, ref=None):
    ''' Plot the estimated parameters over time with their uncertainty bands.

    All windkessel R_d estimates share one figure, the capacitances C share
    a second one (only when theta has 7 columns), and each Dirichlet U gets
    its own figure. Each estimate is drawn as ``2**theta * configured
    value`` with a band ``2**(+/- sqrt(P_ii))`` around it, plus the
    hard-coded true value as a dashed line. Figures are saved as ``Rd.png``
    (in the working directory and in ``results/<case>/``), ``C.png`` (working
    directory) and ``U.png`` (``results/<case>/``).

    Args:
        dat (dict): ROUKF statistics with keys ``'times'``, ``'theta'`` and
            ``'P_theta'``.
        input_file (str): path of the statistics file; the text between its
            first and second '/' names the case whose configuration is read
            from ``results/<case>/input.yaml``.
        deparameterize (bool): accepted for CLI compatibility but unused;
            values are always deparameterized via 2**theta.
        ref: accepted for CLI compatibility but unused.
    '''
    if is_ipython():
        plt.ion()

    results_dir = 'results/' + _case_name(input_file)
    with open(results_dir + '/input.yaml') as file:
        inputfile = yaml.full_load(file)

    # Reference ("true") values drawn as dashed lines, keyed by boundary id.
    true_values = {3: 4800, 4: 7200, 5: 11520, 6: 11520, 2: 75}
    true_values_C = {3: 0.0004, 4: 0.0004, 5: 0.0003, 6: 0.0003}

    dim = dat['theta'].shape[-1]
    # NOTE(review): C is assumed to be estimated exactly when theta has
    # 7 columns -- confirm against the estimation setup.
    RC_flag = dim == 7
    # Debug toggle: mark every 30th time step (plus the last) on the R_d plot.
    meas_flag = False
    line_split = 1.5

    ids, ids_type, current_val, current_val_C, labels = _collect_parameters(
        inputfile, RC_flag)

    fig1, axes1 = plt.subplots(1, 1, figsize=(12, 7))
    if RC_flag:
        fig2, axes2 = plt.subplots(1, 1, figsize=(12, 7))

    t = dat['times']
    theta = dat['theta']
    P = dat['P_theta']
    if dim == 1:
        theta = theta.reshape((-1, 1))
        P = P.reshape((-1, 1, 1))
    if meas_flag:
        t_und = np.append(t[0::30], [t[-1]])
        meas_mark = t_und * 0

    colors = cycle(['C0', 'C1', 'C2', 'C3', 'C4'])
    # Labels are consumed in exactly the order _collect_parameters built them.
    label_iter = iter(labels)
    idx = 0  # theta / P column of the current parameter
    idc = 0  # index into current_val_C
    for cur_key, bc_type, ref_val in zip(ids, ids_type, current_val):
        col_ = next(colors)
        # theta holds log2 scaling factors of the configured values.
        curve = 2**theta[:, idx] * ref_val
        sigma = np.sqrt(P[:, idx, idx])
        std_down = 2**(-sigma) * curve
        std_up = 2**sigma * curve
        dash_curve = true_values[cur_key] + t * 0
        legend_text = (next(label_iter) + '= ' + str(np.round(curve[-1], 2))
                       + '/' + str(true_values[cur_key]) + '$')
        if bc_type == 'dirichlet':
            fig3, axes3 = plt.subplots(1, 1, figsize=(12, 5))
            axes3.plot(t, curve, '-', color=col_, label=legend_text,
                       linewidth=4)
            axes3.fill_between(t, std_down, std_up, alpha=0.3, color=col_)
            axes3.plot(t, dash_curve, color=col_, ls='--', linewidth=3)
            axes3.set_ylabel(r'$U$', fontsize=36)
            axes3.legend(fontsize=36, loc='upper right')
            axes3.set_xlim([-0.01, 0.81])
            axes3.set_xlabel(r'$t (s)$', fontsize=36)
            axes3.set_box_aspect(1/4)
            axes3.tick_params(axis='both', labelsize=28)
            fig3.savefig(results_dir + '/U.png')
        else:
            axes1.plot(t, curve, '-', color=col_, label=legend_text,
                       linewidth=3)
            axes1.fill_between(t, std_down, std_up, alpha=0.3, color=col_)
            axes1.plot(t, dash_curve, color=col_, ls='--', linewidth=3)
            # Compare against idc, not the loop position: the loop position
            # also counts Dirichlet entries and would skip the last C when
            # a Dirichlet boundary precedes the windkessels.
            if RC_flag and idc < len(current_val_C):
                idx += 1  # C sits in the column right after its R_d
                curve_C = 2**theta[:, idx] * current_val_C[idc]
                sigma_C = np.sqrt(P[:, idx, idx])
                legend_C = (next(label_iter) + '= '
                            + str(np.round(curve_C[-1], 6)) + '/'
                            + str(true_values_C[cur_key]) + '$')
                axes2.plot(t, curve_C, '-', color=col_, label=legend_C,
                           linewidth=3)
                axes2.fill_between(t, 2**(-sigma_C) * curve_C,
                                   2**sigma_C * curve_C, alpha=0.3,
                                   color=col_)
                axes2.plot(t, true_values_C[cur_key] + t * 0, color=col_,
                           ls='--', linewidth=3)
                idc += 1
        if meas_flag:
            axes1.plot(t_und, meas_mark + line_split * idx, marker='x',
                       color='red')
        idx += 1

    axes1.set_ylabel(r'$R_d$', fontsize=36)
    axes1.legend(fontsize=36, loc='upper right')
    axes1.set_xlim([-0.01, 0.81])
    axes1.set_ylim([1700, 55000])
    axes1.set_box_aspect(1/2)
    # tick_params targets axes1 itself; plt.xticks would style the current
    # axes, which is the C or U plot whenever one was created later.
    axes1.tick_params(axis='both', labelsize=28)
    axes1.set_xlabel(r'$t (s)$', fontsize=36)
    # Save through fig1: plt.savefig writes the *current* figure instead.
    fig1.savefig('Rd.png')
    if RC_flag:
        axes2.set_ylabel(r'$C$', fontsize=36)
        axes2.legend(fontsize=36, loc='upper right')
        axes2.set_xlim([-0.01, 0.81])
        axes2.tick_params(axis='both', labelsize=28)
        axes2.set_xlabel(r'$t (s)$', fontsize=36)
        fig2.savefig('C.png')
    fig1.savefig(results_dir + '/Rd.png')

    if not is_ipython():
        plt.show()
def get_parser():
parser = argparse.ArgumentParser(
description='''
Plot the time evolution of the ROUKF estimated parameters.
To execute in IPython::
%run plot_roukf_parameters.py [-d] [-r N [N \
...]] file
''',
formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument('file', type=str, help='path to ROUKF stats file')
parser.add_argument('-d', '--deparameterize', action='store_true',
help='deparameterize the parameters by 2**theta')
parser.add_argument('-r', '--ref', metavar='N', nargs='+', default=None,
type=float, help='Reference values for parameters')
return parser
if __name__ == '__main__':
    # Parse the CLI, load the statistics file and draw the figures. The
    # module-level names `args` and `dat` stay available after `%run`.
    args = get_parser().parse_args()
    dat = load_data(args.file)
    plot_parameters(
        dat,
        args.file,
        deparameterize=args.deparameterize,
        ref=args.ref,
    )