adding kalman
This commit is contained in:
		
							
								
								
									
										185
									
								
								kalman/compute_solution_errors.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										185
									
								
								kalman/compute_solution_errors.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,185 @@
 | 
			
		||||
from dolfin import *
 | 
			
		||||
import numpy as np
 | 
			
		||||
from common import inout
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_indices_glob(path):
 | 
			
		||||
    path_all = list(Path().glob(path.format(i='*')))
 | 
			
		||||
    indices = sorted(int(str(s).split('/')[-2]) for s in path_all)
 | 
			
		||||
    return indices
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_checkpoints(options, path_checkpoint):
 | 
			
		||||
 | 
			
		||||
    indices = options['estimation']['measurements'][0]['indices']
 | 
			
		||||
 | 
			
		||||
    # look for u.h5 checkpoints
 | 
			
		||||
    chkpt_root = str(Path(path_checkpoint).joinpath('{i}/u.h5'))
 | 
			
		||||
 | 
			
		||||
    # if indices were given in input file, check if u.h5 checkpoints or X0.h5
 | 
			
		||||
    if indices:
 | 
			
		||||
        if not Path(chkpt_root.format(indices[0])).is_file():
 | 
			
		||||
            chkpt_root = str(Path(path_checkpoint).joinpath('{i}/X0.h5'))
 | 
			
		||||
            if not Path(chkpt_root.format(indices[0])).is_file():
 | 
			
		||||
                raise Exception('No checkpoints found in folder ' + chkpt_root)
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
 | 
			
		||||
        # get indices from folder names of u.h5 checkpoints
 | 
			
		||||
        indices = get_indices_glob(chkpt_root)
 | 
			
		||||
 | 
			
		||||
        # if no indices were found, look for X0.h5 files
 | 
			
		||||
        if not indices:
 | 
			
		||||
            chkpt_root = str(Path(path_checkpoint).joinpath('{i}/X0.h5'))
 | 
			
		||||
            indices = get_indices_glob(chkpt_root)
 | 
			
		||||
 | 
			
		||||
            # still no indices? raise Exception
 | 
			
		||||
            if not indices:
 | 
			
		||||
                print(chkpt_root)
 | 
			
		||||
                raise Exception('No checkpoint indices found')
 | 
			
		||||
 | 
			
		||||
    dt = options['timemarching']['dt']
 | 
			
		||||
    times = dt*np.array(indices)
 | 
			
		||||
 | 
			
		||||
    if MPI.rank(MPI.comm_world) == 0:
 | 
			
		||||
        print('indices: \n')
 | 
			
		||||
        print('\t', indices)
 | 
			
		||||
        print('times: \n')
 | 
			
		||||
        print('\t', times)
 | 
			
		||||
 | 
			
		||||
    files = [chkpt_root.format(i=i) for i in indices]
 | 
			
		||||
 | 
			
		||||
    # check if all files are found
 | 
			
		||||
    for f in files:
 | 
			
		||||
        if not Path(f).is_file():
 | 
			
		||||
            raise FileNotFoundError(f)
 | 
			
		||||
 | 
			
		||||
    return indices, times, files
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_h5_fun_name(file):
 | 
			
		||||
    if 'X0.h5' in file:
 | 
			
		||||
        fun = '/X'
 | 
			
		||||
    else:
 | 
			
		||||
        fun = '/u'
 | 
			
		||||
 | 
			
		||||
    return fun
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def filter_indices(ind1, ind2, files1, files2):
 | 
			
		||||
    ''' Filter indices and files such that only maching files and indicies
 | 
			
		||||
    remain.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        ind1 (list):    list of indices of dataset 1
 | 
			
		||||
        ind2 (list):    list of indices of dataset 2
 | 
			
		||||
        files1 (list):  list of files of dataset 1
 | 
			
		||||
        files2 (list):  list of files of dataset 2
 | 
			
		||||
    '''
 | 
			
		||||
    ind = []
 | 
			
		||||
    files_filt1 = []
 | 
			
		||||
    files_filt2 = []
 | 
			
		||||
    for i in ind1:
 | 
			
		||||
        if i in ind2:
 | 
			
		||||
            ind.append(i)
 | 
			
		||||
            files_filt1.append(files1[list(ind1).index(i)])
 | 
			
		||||
            files_filt2.append(files2[list(ind2).index(i)])
 | 
			
		||||
 | 
			
		||||
    return ind, files_filt1, files_filt2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def compute_errors(inputfile, path_checkpoint_1,
 | 
			
		||||
                   path_checkpoint_2, relative=False):
 | 
			
		||||
 | 
			
		||||
    options = inout.read_parameters(inputfile)
 | 
			
		||||
 | 
			
		||||
    indices1, times1, files1 = find_checkpoints(options, path_checkpoint_1)
 | 
			
		||||
    indices2, times2, files2 = find_checkpoints(options, path_checkpoint_2)
 | 
			
		||||
 | 
			
		||||
    indices, files1, files2 = filter_indices(indices1, indices2, files1,
 | 
			
		||||
                                             files2)
 | 
			
		||||
 | 
			
		||||
    fun1 = get_h5_fun_name(files1[0])
 | 
			
		||||
    fun2 = get_h5_fun_name(files2[0])
 | 
			
		||||
 | 
			
		||||
    # assert np.allclose(indices1, indices2), 'Indices do not match!'
 | 
			
		||||
    # assert np.allclose(times1, times2), 'Time stamps do not match!'
 | 
			
		||||
 | 
			
		||||
    mesh, _, _ = inout.read_mesh(options['mesh'])
 | 
			
		||||
 | 
			
		||||
    if 'fluid' in options:
 | 
			
		||||
        assert options['fem']['velocity_space'] in ('p1', 'p2'), (
 | 
			
		||||
            'velocity space not supported, use p1 or p2')
 | 
			
		||||
        deg = int(options['fem']['velocity_space'][-1])
 | 
			
		||||
 | 
			
		||||
    elif 'material' in options:
 | 
			
		||||
        deg = options['solver']['fe_degree']
 | 
			
		||||
 | 
			
		||||
    V = VectorFunctionSpace(mesh, 'P', deg)
 | 
			
		||||
 | 
			
		||||
    u1 = Function(V)
 | 
			
		||||
    u2 = Function(V)
 | 
			
		||||
 | 
			
		||||
    err_l2 = []
 | 
			
		||||
    err_linf = []
 | 
			
		||||
 | 
			
		||||
    for i, (f1, f2) in enumerate(zip(files1, files2)):
 | 
			
		||||
        # file_ref = str(Path(path_fwd_tentative_checkpoint).joinpath(
 | 
			
		||||
        #     '{i}/u.h5'.format(i=i)))
 | 
			
		||||
        # file_roukf = str(Path(path_roukf_state_checkpoint).joinpath(
 | 
			
		||||
        #     '{i}/X0.h5'.format(i=i)))
 | 
			
		||||
 | 
			
		||||
        t0 = inout.read_HDF5_data(V.mesh().mpi_comm(), f1, u1, fun1)
 | 
			
		||||
        t1 = inout.read_HDF5_data(V.mesh().mpi_comm(), f2, u2, fun2)
 | 
			
		||||
 | 
			
		||||
        assert np.allclose(t0, t1), ('Timestamps do not match! {} vs {} '
 | 
			
		||||
                                     '(HDF5 files)'.format(t0, t1))
 | 
			
		||||
 | 
			
		||||
        if relative:
 | 
			
		||||
            u_l2 = norm(u1, 'l2')
 | 
			
		||||
            u_linf = norm(u1.vector(), 'linf')
 | 
			
		||||
            if u_l2 == 0:
 | 
			
		||||
                u_l2 = 1
 | 
			
		||||
                print('i = {} \t norm(u1) == 0, do not normalize!'.format(i))
 | 
			
		||||
            if u_linf == 0:
 | 
			
		||||
                u_linf = 1
 | 
			
		||||
                print('i = {} \t max(u1) == 0, do not normalize!'.format(i))
 | 
			
		||||
        else:
 | 
			
		||||
            u_l2 = u_linf = 1
 | 
			
		||||
 | 
			
		||||
        err_l2.append(errornorm(u1, u2, 'l2', degree_rise=0)/u_l2)
 | 
			
		||||
        err_linf.append(norm(u1.vector() - u2.vector(), 'linf')/u_linf)
 | 
			
		||||
 | 
			
		||||
        print('i = {} \t L2 error: {} \t Linf error: {}'.format(i, err_l2[-1],
 | 
			
		||||
                                                                err_linf[-1]))
 | 
			
		||||
 | 
			
		||||
    print('max L2 error:   {}'.format(max(err_l2)))
 | 
			
		||||
    print('max Linf error: {}'.format(max(err_linf)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_parser():
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
        description='''
 | 
			
		||||
Compute errors between vector function checkpoints of two simulations.
 | 
			
		||||
 | 
			
		||||
Can be ROUKF checkpoints, but if forward and ROUKF
 | 
			
		||||
files are found in the same checkpoint folder (0/u.h5 and 0/X.5),
 | 
			
		||||
the forward file will be preferred.''',
 | 
			
		||||
        formatter_class=argparse.RawDescriptionHelpFormatter)
 | 
			
		||||
    parser.add_argument('inputfile', type=str, help='path to yaml input file')
 | 
			
		||||
    parser.add_argument('path_checkpoint_1', type=str,
 | 
			
		||||
                        help='Path to checkpoints of simulation 1')
 | 
			
		||||
    parser.add_argument('path_checkpoint_2', type=str,
 | 
			
		||||
                        help='Path to checkpoints of simulation 2')
 | 
			
		||||
    parser.add_argument('-r', '--relative', action='store_true',
 | 
			
		||||
                        help='compute relative errors')
 | 
			
		||||
    return parser
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    args = get_parser().parse_args()
 | 
			
		||||
 | 
			
		||||
    compute_errors(args.inputfile, args.path_checkpoint_1,
 | 
			
		||||
                   args.path_checkpoint_2, relative=args.relative)
 | 
			
		||||
		Reference in New Issue
	
	Block a user