You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

247 lines
6.9 KiB

import matplotlib.pyplot as plt
import numpy as np
from itertools import cycle
import argparse
import pickle
import yaml
from matplotlib import rc
# Global matplotlib text settings: sans-serif family with Helvetica as the
# preferred face, and text rendered through LaTeX (text.usetex).
# NOTE(review): usetex=True needs a working LaTeX toolchain on the machine
# running this script; the '$...$' legend labels built below rely on it.
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
rc('text', usetex=True)
import matplotlib.font_manager
def is_ipython():
    """Tell whether the code is currently running inside IPython.

    IPython injects a ``get_ipython`` builtin; plain CPython does not, so the
    call below raises ``NameError`` outside an IPython session.

    Returns:
        bool: True when running under IPython, otherwise False.
    """
    try:
        get_ipython()  # noqa: F821 - only defined when IPython is active
    except NameError:
        return False
    return True
def load_data(file):
    """Load a NumPy data file fully into memory.

    For ``.npz`` archives every array is read eagerly into a plain ``dict``
    and the archive is closed, so no file handle stays open after the call
    (``np.load`` on an archive otherwise returns a lazy ``NpzFile`` that keeps
    the file open). Any other ``np.load`` result, such as a single array from
    a ``.npy`` file, is returned unchanged.

    Args:
        file: path or file-like object accepted by ``np.load``.

    Returns:
        dict: mapping of array name to ``np.ndarray`` for ``.npz`` input.
    """
    loaded = np.load(file)
    # Only NpzFile archives expose ``.files``; plain arrays are returned as-is.
    if not hasattr(loaded, 'files'):
        return loaded
    with loaded:
        return {name: loaded[name] for name in loaded.files}
# Reference ("true") values drawn as dashed target lines, keyed by
# boundary-condition id. Hard-coded for one specific experiment setup.
_TRUE_VALUES_R = {
    3: 4800,
    4: 7200,
    5: 11520,
    6: 11520,
    2: 75,
}
_TRUE_VALUES_C = {
    3: 0.0004,
    4: 0.0004,
    5: 0.0003,
    6: 0.0003,
}


def _experiment_name(input_file):
    """Return the path component between the first and second '/'.

    For a stats file such as ``results/<name>/<stats>.npz`` this is
    ``<name>``, the directory that also holds the run's ``input.yaml``.

    Raises:
        ValueError: if *input_file* contains fewer than two '/' separators.
    """
    parts = input_file.split('/')
    if len(parts) < 3:
        raise ValueError(
            "expected a path like 'results/<name>/<file>', got %r" % input_file)
    return parts[1]


def _load_input_config(input_file):
    """Read ``results/<name>/input.yaml`` for the run that wrote *input_file*."""
    config_path = 'results/' + _experiment_name(input_file) + '/input.yaml'
    with open(config_path) as stream:
        # NOTE(review): full_load is not intended for untrusted YAML; these
        # are presumably locally generated run configurations.
        return yaml.full_load(stream)


def _collect_estimated_parameters(config, with_capacitance):
    """List the estimated boundary conditions in estimation-config order.

    Each record is a dict with keys ``id``, ``type`` ('windkessel' or
    'dirichlet'), ``value`` (nominal R_d or U) and ``label``; Windkessel
    records additionally carry ``value_C``/``label_C`` when
    *with_capacitance* is set.
    """
    records = []
    for est_bc in config['estimation']['boundary_conditions']:
        if 'windkessel' in est_bc['type']:
            for bc in config['boundary_conditions']:
                if bc['id'] != est_bc['id']:
                    continue
                record = {
                    'id': est_bc['id'],
                    'type': 'windkessel',
                    'value': bc['parameters']['R_d'],
                    'label': '$R_' + str(est_bc['id']),
                }
                if with_capacitance:
                    record['value_C'] = bc['parameters']['C']
                    record['label_C'] = '$C_' + str(est_bc['id'])
                records.append(record)
        elif 'dirichlet' in est_bc['type']:
            # NOTE(review): U is read from the first boundary condition, not
            # from the one whose id matches est_bc['id'] -- confirm intended.
            records.append({
                'id': est_bc['id'],
                'type': 'dirichlet',
                'value': config['boundary_conditions'][0]['parameters']['U'],
                'label': '$U',
            })
    return records


def _plot_estimate(ax, t, log2_scale, variance, nominal, true_value, label,
                   color, decimals):
    """Draw one scaled parameter trajectory, its band and its true value."""
    curve = 2**log2_scale * nominal
    sigma = np.sqrt(variance)
    final = np.round(curve[-1], decimals)
    ax.plot(t, curve, '-', color=color,
            label=label + '= ' + str(final) + '/' + str(true_value) + '$',
            linewidth=2)
    # Band is multiplicative: curve * 2**(-sigma) .. curve * 2**(+sigma).
    ax.fill_between(t, 2**(-sigma) * curve, 2**sigma * curve,
                    alpha=0.3, color=color)
    ax.plot(t, true_value + t * 0, color=color, ls='--')


def plot_parameters(dat, input_file, deparameterize=False, ref=None):
    """Plot the estimated Windkessel parameters with their uncertainty bands.

    For each estimated Windkessel boundary condition, the nominal value from
    the run's ``input.yaml`` is scaled by ``2**theta`` and plotted over time,
    with a band of ``curve * 2**(+/-sqrt(P))`` and a dashed line at the
    hard-coded true value. R_d curves are saved to ``Rd.png``; when theta has
    7 columns the capacitances C are plotted too and saved to ``C.png``.
    Dirichlet parameters occupy one theta column but are not plotted.

    Args:
        dat (mapping): data with keys ``'times'``, ``'theta'`` and
            ``'P_theta'``; the diagonal of ``P_theta`` is used as the
            per-step variance of each theta column.
        input_file (str): path of the stats file, used to locate
            ``results/<name>/input.yaml``.
        deparameterize (bool): currently ignored; values are always shown as
            ``2**theta`` times the nominal value.
        ref: currently ignored.
    """
    if is_ipython():
        plt.ion()
    config = _load_input_config(input_file)

    t = dat['times']
    theta = dat['theta']
    cov = dat['P_theta']
    # Normalize the single-parameter case to (n_steps, 1) / (n_steps, 1, 1)
    # so the column indexing below works uniformly. This must happen before
    # counting columns, otherwise a flat series is mistaken for many params.
    if theta.ndim == 1 or theta.shape[-1] == 1:
        theta = theta.reshape((-1, 1))
        cov = cov.reshape((-1, 1, 1))
    # NOTE(review): 7 columns is presumably "R and C for three Windkessels
    # plus one Dirichlet U"; C is estimated/plotted only in that case.
    rc_flag = theta.shape[-1] == 7

    records = _collect_estimated_parameters(config, rc_flag)

    fig_r, ax_r = plt.subplots(1, 1, figsize=(12, 6))
    if rc_flag:
        fig_c, ax_c = plt.subplots(1, 1, figsize=(12, 6))
    colors = cycle(['C0', 'C1', 'C2', 'C3', 'C4'])

    col = 0  # theta column holding the current record's first parameter
    for record in records:
        color = next(colors)
        if record['type'] == 'dirichlet':
            # U takes one theta column but is not plotted.
            col += 1
            continue
        bc_id = record['id']
        _plot_estimate(ax_r, t, theta[:, col], cov[:, col, col],
                       record['value'], _TRUE_VALUES_R[bc_id],
                       record['label'], color, decimals=2)
        if rc_flag:
            # C sits in the column right after this boundary's R_d.
            _plot_estimate(ax_c, t, theta[:, col + 1],
                           cov[:, col + 1, col + 1],
                           record['value_C'], _TRUE_VALUES_C[bc_id],
                           record['label_C'], color, decimals=6)
            col += 1
        col += 1

    ax_r.set_ylabel(r'$R_d$', fontsize=22)
    ax_r.legend(fontsize=18, loc='upper right')
    ax_r.set_xlim([-0.01, 0.81])
    ax_r.set_ylim([1700, 35000])
    ax_r.set_xlabel(r'$t (s)$', fontsize=22)
    if rc_flag:
        ax_c.set_ylabel(r'$C$', fontsize=22)
        ax_c.legend(fontsize=18, loc='upper right')
        ax_c.set_xlim([-0.01, 0.81])
        ax_c.set_xlabel(r'$t (s)$', fontsize=22)
        fig_c.savefig('C.png')
    fig_r.savefig('Rd.png')
    if not is_ipython():
        plt.show()
def get_parser():
    """Build the command-line interface of this plotting script.

    Returns:
        argparse.ArgumentParser: parser taking a positional stats ``file``,
        an optional ``-d/--deparameterize`` switch and an optional
        ``-r/--ref`` list of floats.
    """
    usage_text = (
        '\n'
        'Plot the time evolution of the ROUKF estimated parameters.\n'
        'To execute in IPython::\n'
        '%run plot_roukf_parameters.py [-d] [-r N [N ...]] file\n'
    )
    cli = argparse.ArgumentParser(
        description=usage_text,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument('file', type=str, help='path to ROUKF stats file')
    cli.add_argument(
        '-d', '--deparameterize',
        action='store_true',
        help='deparameterize the parameters by 2**theta',
    )
    cli.add_argument(
        '-r', '--ref',
        metavar='N',
        nargs='+',
        default=None,
        type=float,
        help='Reference values for parameters',
    )
    return cli
if __name__ == '__main__':
    # Parse the CLI, load the stats file, then hand both to the plotter.
    cli_args = get_parser().parse_args()
    plot_parameters(
        load_data(cli_args.file),
        cli_args.file,
        deparameterize=cli_args.deparameterize,
        ref=cli_args.ref,
    )