In [1]:
# IMPORTS

import copy
import csv
import math
import os
import random
from itertools import product
from time import sleep

import numpy as np
import pandas as pd
import scipy as scp
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from numpy.random import RandomState
from scipy.spatial.distance import euclidean
from tqdm import tqdm

import tensorflow as tf
import tensorflow.keras as kr
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.datasets import mnist

from keras.models import Sequential
from keras.layers import Dense
from keras import optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.utils import to_categorical
from keras import backend as K

from sklearn import mixture
from sklearn.compose import ColumnTransformer
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score, roc_auc_score)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
%matplotlib inline
Using TensorFlow backend.
In [2]:
# Enter here the data set you want to explain (adult, activity, or synthetic)

data_set = 'adult'

# Enter here the number of peers you want in the experiments

n_peers = 100
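As a rough sketch of what n_peers implies, one simple way to hand each peer an equal, disjoint shard of the training data (the helper peer_shards and the even random split are illustrative assumptions, not the notebook's actual partitioning):

# Hypothetical helper: split n sample indices evenly among n_peers peers.
def peer_shards(n, n_peers, rs):
    idx = rs.permutation(n)               # reproducible shuffle of the indices
    return np.array_split(idx, n_peers)   # n_peers near-equal shards

# e.g. shards = peer_shards(len(X_train), n_peers, RandomState(92))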
In [3]:
# The random state used in the experiments; it can be changed

rs = RandomState(92)
In [4]:
# Preprocessing the Adult data set

if data_set == 'adult':
    # Load dataset into a pandas DataFrame
    adult_data = pd.read_csv('adult_data.csv', na_values='?')
    # Drop all records with missing values
    adult_data.dropna(inplace=True)
    adult_data.reset_index(drop=True, inplace=True)

    # Drop fnlwgt, not interesting for ML
    adult_data.drop('fnlwgt', axis=1, inplace=True)
    adult_data.drop('education', axis=1, inplace=True)

    # Merge similar marital-status categories into 'Married' / 'Unmarried'
    adult_data['marital-status'].replace({'Married-civ-spouse': 'Married',
                                          'Married-spouse-absent': 'Married',
                                          'Married-AF-spouse': 'Married',
                                          'Divorced': 'Unmarried',
                                          'Never-married': 'Unmarried',
                                          'Separated': 'Unmarried',
                                          'Widowed': 'Unmarried'}, inplace=True)
    
    adult_data = pd.concat([adult_data,pd.get_dummies(adult_data['income'], prefix='income')],axis=1)
    adult_data.drop('income', axis=1, inplace=True)
    obj_columns = adult_data.select_dtypes(['object']).columns
    adult_data[obj_columns] = adult_data[obj_columns].astype('category')
    # Convert numerics to floats and min-max normalize to [0, 1]
    num_columns = adult_data.select_dtypes(['int64']).columns
    adult_data[num_columns] = adult_data[num_columns].astype('float64')
    for c in num_columns:
        adult_data[c] = (adult_data[c] - adult_data[c].min()) / (adult_data[c].max() - adult_data[c].min())
    # Encode the remaining categorical columns as numeric codes
    adult_data['marital-status'] = adult_data['marital-status'].cat.codes
    adult_data['occupation'] = adult_data['occupation'].cat.codes
    adult_data['relationship'] = adult_data['relationship'].cat.codes
    adult_data['race'] = adult_data['race'].cat.codes
    adult_data['gender'] = adult_data['gender'].cat.codes
    adult_data['native-country'] = adult_data['native-country'].cat.codes
    adult_data['workclass'] = adult_data['workclass'].cat.codes

    # Normalize the encoded categorical columns the same way
    num_columns = adult_data.select_dtypes(['int8']).columns
    adult_data[num_columns] = adult_data[num_columns].astype('float64')
    for c in num_columns:
        adult_data[c] = (adult_data[c] - adult_data[c].min()) / (adult_data[c].max() - adult_data[c].min())
    display(adult_data.info())
    display(adult_data.head(10))
    
    adult_data = adult_data.to_numpy()
    
    # Split the data into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(adult_data[:,:-2], adult_data[:,-2:], test_size=0.07, random_state=rs)
    # Feature names
    names = ['age','workclass','educational-num','marital-status','occupation',
         'relationship','race','gender','capital-gain','capital-loss','hours-per-week','native-country']
    Features_number = len(X_train[0])
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 45222 entries, 0 to 45221
Data columns (total 14 columns):
age                45222 non-null float64
workclass          45222 non-null float64
educational-num    45222 non-null float64
marital-status     45222 non-null float64
occupation         45222 non-null float64
relationship       45222 non-null float64
race               45222 non-null float64
gender             45222 non-null float64
capital-gain       45222 non-null float64
capital-loss       45222 non-null float64
hours-per-week     45222 non-null float64
native-country     45222 non-null float64
income_<=50K       45222 non-null uint8
income_>50K        45222 non-null uint8
dtypes: float64(12), uint8(2)
memory usage: 4.2 MB
None
age workclass educational-num marital-status occupation relationship race gender capital-gain capital-loss hours-per-week native-country income_<=50K income_>50K
0 0.109589 0.333333 0.400000 1.0 0.461538 0.6 0.5 1.0 0.000000 0.0 0.397959 0.95 1 0
1 0.287671 0.333333 0.533333 0.0 0.307692 0.0 1.0 1.0 0.000000 0.0 0.500000 0.95 1 0
2 0.150685 0.166667 0.733333 0.0 0.769231 0.0 1.0 1.0 0.000000 0.0 0.397959 0.95 0 1
3 0.369863 0.333333 0.600000 0.0 0.461538 0.0 0.5 1.0 0.076881 0.0 0.397959 0.95 0 1
4 0.232877 0.333333 0.333333 1.0 0.538462 0.2 1.0 1.0 0.000000 0.0 0.295918 0.95 1 0
5 0.630137 0.666667 0.933333 0.0 0.692308 0.0 1.0 1.0 0.031030 0.0 0.316327 0.95 0 1
6 0.095890 0.333333 0.600000 1.0 0.538462 0.8 1.0 0.0 0.000000 0.0 0.397959 0.95 1 0
7 0.520548 0.333333 0.200000 0.0 0.153846 0.0 1.0 1.0 0.000000 0.0 0.091837 0.95 1 0
8 0.657534 0.333333 0.533333 0.0 0.461538 0.0 1.0 1.0 0.064181 0.0 0.397959 0.95 0 1
9 0.260274 0.000000 0.800000 0.0 0.000000 0.0 1.0 1.0 0.000000 0.0 0.397959 0.95 1 0
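For reference, the min-max scaling above maps x to (x - min) / (max - min); assuming the usual Adult age range of 17 to 90, an age of 25 maps to (25 - 17) / (90 - 17) ≈ 0.1096, which matches row 0 of the table.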
In [5]:
if data_set == 'synthetic':
    # Generate the data
    X, y = make_classification(n_samples=1000000, n_features=10, n_redundant=3, n_repeated=2, #n_classes=3, 
                           n_informative=5, n_clusters_per_class=4, 
                           random_state=42)
    y = pd.DataFrame(data=y, columns=["y"])
    y = pd.get_dummies(y['y'], prefix='y')
    y = y.to_numpy()
    X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.07, random_state=rs)
    # Feature names
    names = ['X(0)','X(1)','X(2)','X(3)','X(4)','X(5)','X(6)','X(7)','X(8)','X(9)']
    Features_number = len(X_train[0])
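For context: with these arguments make_classification draws 5 informative features, adds 3 redundant features that are linear combinations of them and 2 repeated (duplicated) ones, giving 10 features over 2 classes with 4 clusters per class.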
In [6]:
if data_set == 'activity':
    # Load dataset into a pandas DataFrame
    activity = pd.read_csv("activity_3_original.csv", sep=',')
    # Drop features that are missing in the majority of the samples
    to_drop = ['subject', 'timestamp', 'heart_rate', 'activityID']
    activity.drop(axis=1, columns=to_drop, inplace=True)
    # Prepare the one-hot ground-truth labels
    activity = pd.concat([activity,pd.get_dummies(activity['motion'], prefix='motion')],axis=1)
    activity.drop('motion', axis=1, inplace=True)
    class_label = [ 'motion_n', 'motion_y']
    predictors = [a for a in activity.columns.values if a not in class_label]

    for p in predictors:
        activity[p].fillna(activity[p].mean(), inplace=True)

    display(predictors)
    for p in predictors:
        activity[p] = (activity[p] - activity[p].min()) / (activity[p].max() - activity[p].min())
        activity[p] = activity[p].astype('float32')
    activity = activity.to_numpy()
    X_train, X_test, y_train, y_test = train_test_split(activity[:,:-2],activity[:,-2:], test_size=0.07, random_state=rs)
    # Feature names
    names = ['temp_hand','acceleration_16_x_hand',
        'acceleration_16_y_hand','acceleration_16_z_hand','acceleration_6_x_hand',
        'acceleration_6_y_hand','acceleration_6_z_hand','gyroscope_x_hand','gyroscope_y_hand',
        'gyroscope_z_hand','magnetometer_x_hand','magnetometer_y_hand','magnetometer_z_hand',
        'temp_chest','acceleration_16_x_chest','acceleration_16_y_chest','acceleration_16_z_chest','acceleration_6_x_chest',
        'acceleration_6_y_chest','acceleration_6_z_chest','gyroscope_x_chest','gyroscope_y_chest','gyroscope_z_chest',
        'magnetometer_x_chest','magnetometer_y_chest','magnetometer_z_chest','temp_ankle','acceleration_16_x_ankle',
        'acceleration_16_y_ankle','acceleration_16_z_ankle','acceleration_6_x_ankle','acceleration_6_y_ankle',
        'acceleration_6_z_ankle','gyroscope_x_ankle','gyroscope_y_ankle','gyroscope_z_ankle','magnetometer_x_ankle',
        'magnetometer_y_ankle','magnetometer_z_ankle']
    Features_number = len(X_train[0])
In [7]:
# Begin federated learning: define the callbacks and the global model

earlystopping = EarlyStopping(monitor = 'val_loss',
                              min_delta = 0.01,
                              patience = 50,
                              verbose = 1,
                              baseline = 2,
                              restore_best_weights = True)

checkpoint = ModelCheckpoint('test.h8',
                             monitor='val_loss',
                             mode='min',
                             save_best_only=True,
                             verbose=1)
    
model = Sequential()
model.add(Dense(70, input_dim=Features_number, activation='relu'))
model.add(Dense(50, activation='relu'))
model.add(Dense(50, activation='relu'))
model.add(Dense(2, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])
history = model.fit(X_train, y_train,
                    epochs=2,
                    validation_data=(X_test, y_test),
                    callbacks=[checkpoint, earlystopping],
                    shuffle=True)
Train on 42056 samples, validate on 3166 samples
Epoch 1/2
42056/42056 [==============================] - 2s 41us/step - loss: 0.3720 - accuracy: 0.8217 - val_loss: 0.3452 - val_accuracy: 0.8370

Epoch 00001: val_loss improved from inf to 0.34516, saving model to test.h8
Epoch 2/2
42056/42056 [==============================] - 2s 41us/step - loss: 0.3412 - accuracy: 0.8375 - val_loss: 0.3417 - val_accuracy: 0.8370

Epoch 00002: val_loss improved from 0.34516 to 0.34170, saving model to test.h8
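A quick sanity check of the trained model on the held-out split, using the sklearn metrics imported above (an illustrative sketch, not part of the original cell):

# Illustrative sketch: evaluate the black-box on the test split.
y_pred = np.argmax(model.predict(X_test), axis=1)   # predicted class index
y_true = np.argmax(y_test, axis=1)                  # true class index
print('accuracy :', accuracy_score(y_true, y_pred))
print('f1 score :', f1_score(y_true, y_pred))
print(confusion_matrix(y_true, y_pred))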
In [8]:
# AUXILIARY METHODS FOR FEDERATED LEARNING

# RETURN INDICES TO LAYERS WITH WEIGHTS AND BIASES
def trainable_layers(model):
    return [i for i, layer in enumerate(model.layers) if len(layer.get_weights()) > 0]

# RETURN WEIGHTS AND BIASES OF A MODEL
def get_parameters(model):
    weights = []
    biases = []
    index = trainable_layers(model)
    for i in index:
        weights.append(copy.deepcopy(model.layers[i].get_weights()[0]))
        biases.append(copy.deepcopy(model.layers[i].get_weights()[1]))           
    
    return weights, biases
        
# SET WEIGHTS AND BIASES OF A MODEL
def set_parameters(model, weights, biases):
    index = trainable_layers(model)
    for i, j in enumerate(index):
        model.layers[j].set_weights([weights[i], biases[i]])
    
# DEPRECATED: RETURN THE GRADIENTS OF THE MODEL AFTER AN UPDATE 
def get_gradients(model, inputs, outputs):
    """ Gets gradient of model for given inputs and outputs for all weights"""
    grads = model.optimizer.get_gradients(model.total_loss, model.trainable_weights)
    symb_inputs = (model._feed_inputs + model._feed_targets + model._feed_sample_weights)
    f = K.function(symb_inputs, grads)
    x, y, sample_weight = model._standardize_user_data(inputs, outputs)
    output_grad = f(x + y + sample_weight)
    
    w_grad = [w for i,w in enumerate(output_grad) if i%2==0]
    b_grad = [w for i,w in enumerate(output_grad) if i%2==1]
    
    return w_grad, b_grad

# RETURN THE DIFFERENCE OF MODELS' WEIGHTS AND BIASES AFTER AN UPDATE 
# NOTE: LEARNING RATE IS APPLIED, SO THE UPDATE IS DIFFERENT FROM THE
# GRADIENTS. IN CASE VANILLA SGD IS USED, THE GRADIENTS ARE OBTAINED
# AS (UPDATES / LEARNING_RATE)
def get_updates(model, inputs, outputs, batch_size, epochs):
    w, b = get_parameters(model)
    #model.train_on_batch(inputs, outputs)
    model.fit(inputs, outputs, batch_size=batch_size, epochs=epochs, verbose=0)
    w_new, b_new = get_parameters(model)
    
    weight_updates = [old - new for old,new in zip(w, w_new)]
    bias_updates = [old - new for old,new in zip(b, b_new)]
    
    return weight_updates, bias_updates

# UPDATE THE MODEL'S WEIGHTS AND BIASES WITH AN UPDATE
def apply_updates(model, eta, w_new, b_new):
    w, b = get_parameters(model)
    new_weights = [theta - eta*delta for theta,delta in zip(w, w_new)]
    new_biases = [theta - eta*delta for theta,delta in zip(b, b_new)]
    set_parameters(model, new_weights, new_biases)
    
# FEDERATED AGGREGATION FUNCTION
def aggregate(n_layers, n_peers, f, w_updates, b_updates):
    agg_w = [f([w_updates[j][i] for j in range(n_peers)], axis=0) for i in range(n_layers)]
    agg_b = [f([b_updates[j][i] for j in range(n_peers)], axis=0) for i in range(n_layers)]
    return agg_w, agg_b
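# e.g. aggregate(4, n_peers, np.mean, w_updates, b_updates) averages every
# layer's update across all peers (FedAvg-style); np.median would be a more
# robust alternative. (Illustrative note, not used below.)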

# REPLACE NANS AND INFS WITH ZEROS
def nans_to_zero(W, B):
    W0 = [np.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0) for w in W]
    B0 = [np.nan_to_num(b, nan=0.0, posinf=0.0, neginf=0.0) for b in B]
    return W0, B0

def build_forest(X,y):
    clf=RandomForestClassifier(n_estimators=1000, max_depth=7, random_state=0, verbose = 1)
    clf.fit(X,y)
    return clf

# COMPUTE EUCLIDEAN DISTANCE OF WEIGHTS
def dist_weights(w_a, w_b):
    wf_a = flatten_weights(w_a)
    wf_b = flatten_weights(w_b)
    return euclidean(wf_a, wf_b)

# TRANSFORM ALL WEIGHT TENSORS TO 1D ARRAY
def flatten_weights(w_in):
    h = w_in[0].reshape(-1)
    for w in w_in[1:]:
        h = np.append(h, w.reshape(-1))
    return h
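Taken together, these helpers support a FedAvg-style round. Below is a minimal sketch of one round; the per-peer shards, the use of clone_model, and np.mean as the aggregation function f are illustrative assumptions, not code from the notebook:

from keras.models import clone_model

n_layers = len(trainable_layers(model))
shards = np.array_split(rs.permutation(len(X_train)), n_peers)

w0, b0 = get_parameters(model)                 # current global parameters
w_updates, b_updates = [], []
for idx in shards:
    peer = clone_model(model)                  # fresh copy of the architecture
    peer.compile(loss='categorical_crossentropy', optimizer='adam')
    set_parameters(peer, w0, b0)               # each peer starts from the global model
    wu, bu = get_updates(peer, X_train[idx], y_train[idx], batch_size=32, epochs=1)
    w_updates.append(wu)
    b_updates.append(bu)

# Average the updates across peers and apply them to the global model
agg_w, agg_b = aggregate(n_layers, n_peers, np.mean, w_updates, b_updates)
apply_updates(model, 1.0, agg_w, agg_b)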
    
In [9]:
# Scan the forest for trees that match the wrong predictions of the black-box
def scan_wrong(forest_predictions, FL_predict1, forest, y_test_local, X_test_local):
    sum_feature_importance = 0
    overall_wrong_feature_importance = 0
    counter = 0
    second_counter = 0
    never_seen = 0
    avg_wrong_importance = 0
    FL_predict1 = np.argmax(FL_predict1, axis=1)
    forest_predictions = np.argmax(forest_predictions, axis=1)
    y_test_local = np.argmax(y_test_local, axis=1)
    for i in range(len(FL_predict1)):
        i_tree = 0
        # If the black-box made a wrong prediction
        if (FL_predict1[i] != y_test_local[i]):
            # Get the predictions of the trees one by one
            for tree_in_forest in forest.estimators_:
                sample = X_test_local[i].reshape(1, -1)
                temp = forest.estimators_[i_tree].predict(sample)
                temp = np.argmax(temp, axis=1)
                i_tree = i_tree + 1
                # If the prediction of the tree matches the prediction of the black-box
                if (FL_predict1[i] == temp):
                    # Accumulate the tree's feature importances
                    sum_feature_importance = sum_feature_importance + tree_in_forest.feature_importances_
                    counter = counter + 1
        # If some trees match the black-box prediction
        if (counter > 0):
            avg_feature_importance = sum_feature_importance / counter
            overall_wrong_feature_importance = avg_feature_importance + overall_wrong_feature_importance
            second_counter = second_counter + 1
            counter = 0
            sum_feature_importance = 0
        # If no tree matches the black-box prediction
        else:
            if (FL_predict1[i] != y_test_local[i]):
                never_seen = never_seen + 1
    # Average the feature importances over all samples that had wrong predictions
    if (second_counter > 0):
        avg_wrong_importance = overall_wrong_feature_importance / second_counter
    return forest.feature_importances_
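A sketch of how scan_wrong might be driven: fit the surrogate forest on the black-box's own predictions (re-encoded as one-hot labels, an assumption for illustration), then scan it against the test split:

# Illustrative sketch: build a surrogate forest of the black-box and scan it.
bb_labels = np.eye(2)[np.argmax(model.predict(X_train), axis=1)]  # one-hot black-box labels
forest = build_forest(X_train, bb_labels)

FL_predict = model.predict(X_test)            # black-box predictions to explain
forest_predictions = forest.predict(X_test)   # surrogate predictions (one-hot)
importances = scan_wrong(forest_predictions, FL_predict, forest, y_test, X_test)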
In [10]:
trainable_layers(model)
Out[10]:
[0, 1, 2, 3]
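All four Dense layers carry weights and biases, so their indices 0-3 are returned.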
In [11]:
get_parameters(model)
Out[11]:
(weights and biases of the four trainable Dense layers, elided for brevity; weight shapes (12, 70), (70, 50), (50, 50), (50, 2) and bias shapes (70,), (50,), (50,), (2,))
In [12]:
get_updates(model, X_train, y_train, 32, 2)
Out[12]:
(per-layer weight and bias updates, elided for brevity; same shapes as the parameters above)
         ...,
         [ 2.19667554e-02,  1.55268461e-02,  1.16532966e-02, ...,
          -6.69102073e-02, -4.36145514e-02,  8.49020481e-03],
         [ 4.64867577e-02,  2.68688127e-02, -7.21535087e-03, ...,
          -1.35802627e-02,  1.19180009e-02,  1.85405612e-02],
         [-4.64454293e-04,  2.76498124e-02, -5.06965816e-03, ...,
           2.36923993e-03,  1.80526227e-02,  2.38161311e-02]], dtype=float32),
  array([[-0.03158482,  0.03222603,  0.02458179, ...,  0.08041045,
           0.01049814, -0.04392172],
         [ 0.01048962, -0.01012598, -0.01415333, ..., -0.13836852,
           0.00485021, -0.01224236],
         [-0.00740135, -0.00507368,  0.00139559, ..., -0.13694564,
          -0.00293422,  0.00983405],
         ...,
         [ 0.00256901, -0.00454452,  0.00040635, ...,  0.00017413,
           0.0033875 ,  0.03571756],
         [-0.01601908,  0.00801374, -0.00454768, ..., -0.10811353,
           0.01376003, -0.00697925],
         [-0.0469159 ,  0.01431713,  0.0028784 , ...,  0.02210182,
          -0.01271321, -0.04000308]], dtype=float32),
  array([[-1.15624368e-02,  1.15615055e-02],
         [ 7.26526976e-03, -7.26565719e-03],
         [ 2.03531086e-02, -2.03537643e-02],
         [-6.14860654e-03,  6.14863634e-03],
         [-1.56270713e-02,  1.56272203e-02],
         [-1.02237239e-03,  1.02218986e-03],
         [ 1.03584975e-02, -1.03584863e-02],
         [ 2.06144452e-02, -2.06139088e-02],
         [-1.49596017e-04,  1.48773193e-04],
         [ 4.79128957e-03, -4.79125977e-03],
         [ 6.30669296e-03, -6.30635023e-03],
         [ 0.00000000e+00,  0.00000000e+00],
         [-3.16215307e-03,  3.16265225e-03],
         [ 5.85643947e-02, -5.85644245e-02],
         [-7.11206347e-04,  7.11120665e-04],
         [-1.32675469e-02,  1.32675469e-02],
         [-4.61159647e-03,  4.61125374e-03],
         [ 1.90031528e-03, -1.90034509e-03],
         [ 1.55715942e-02, -1.55715793e-02],
         [-1.19959861e-02,  1.19959936e-02],
         [-1.39654800e-02,  1.39658749e-02],
         [-5.74159622e-03,  5.74158132e-03],
         [-2.31015980e-02,  2.31016129e-02],
         [-3.14235687e-04,  3.14325094e-04],
         [ 2.00573653e-02, -2.00573802e-02],
         [ 0.00000000e+00,  0.00000000e+00],
         [-1.67117715e-02,  1.67115927e-02],
         [-9.77398362e-03,  9.77447629e-03],
         [ 0.00000000e+00,  0.00000000e+00],
         [-1.44827366e-02,  1.44827068e-02],
         [-4.94149094e-03,  4.94125485e-03],
         [-7.87574053e-03,  7.87599012e-03],
         [-1.95778906e-03,  1.95809081e-03],
         [ 5.55017591e-02, -5.55018783e-02],
         [ 1.07816160e-02, -1.07815862e-02],
         [-3.71644944e-02,  3.71643975e-02],
         [-2.17199326e-04,  2.16580927e-04],
         [ 5.87260723e-03, -5.87230921e-03],
         [-2.86758766e-02,  2.86761075e-02],
         [-2.14453358e-02,  2.14453079e-02],
         [-1.89587176e-02,  1.89587921e-02],
         [-2.95607746e-03,  2.95570493e-03],
         [ 1.51866078e-02, -1.51871741e-02],
         [-1.82390213e-05,  1.85519457e-05],
         [-3.47673893e-04,  3.47673893e-04],
         [ 2.12892145e-02, -2.12893784e-02],
         [-3.14503908e-04,  3.14474106e-04],
         [-2.15784311e-02,  2.15792507e-02],
         [-3.74348462e-03,  3.74361873e-03],
         [-4.07821834e-02,  4.07828614e-02]], dtype=float32)],
 [array([ 1.93715096e-06, -7.33081996e-03, -2.59644464e-02,  3.53913940e-02,
         -3.17766666e-02,  2.61158086e-02,  3.09046637e-02, -1.39724910e-02,
          1.64157562e-02, -1.33834183e-02, -2.69608200e-02,  1.18504055e-02,
         -3.39354798e-02,  2.74997652e-02,  8.21013749e-03,  3.14356387e-03,
         -1.39056407e-02,  3.89015339e-02, -4.97761788e-03,  4.34829444e-02,
          1.99493878e-02,  1.66720618e-02, -8.93015787e-03,  1.38194822e-02,
         -3.19721177e-02, -1.15122199e-02,  3.29295881e-02, -1.12637132e-03,
          2.69653797e-02,  2.71573663e-06,  2.13105604e-02,  8.05605203e-03,
         -1.68812610e-02,  6.44891895e-03,  1.83254965e-02,  3.12887095e-02,
          0.00000000e+00, -1.13615617e-02, -1.91472992e-02,  2.73335315e-02,
         -8.05212930e-03,  1.09619126e-02,  2.75111161e-02,  7.87547231e-03,
          5.21017611e-02,  8.26306641e-04,  4.48312610e-03,  1.90423653e-02,
         -9.32561979e-03, -1.31163374e-02, -7.42557272e-03, -8.75856541e-03,
         -3.20619484e-03,  1.95884481e-02, -3.87861580e-02,  7.03234226e-04,
         -4.51306719e-03,  1.04681477e-02,  9.19631869e-03,  4.49468940e-03,
          8.59957188e-04, -1.29299201e-02,  2.50575244e-02, -6.41628355e-03,
          1.74099207e-02,  6.38409331e-03, -1.43194981e-02, -5.27505018e-03,
         -2.04372257e-02,  3.64710018e-02], dtype=float32),
  array([ 0.00331618,  0.0539567 , -0.03811157, -0.01109656,  0.02418597,
         -0.02698364, -0.00161759, -0.02607356,  0.00603188,  0.0166943 ,
          0.00161197,  0.01395666, -0.02311181, -0.02586061,  0.01232279,
          0.        ,  0.00035155,  0.01375218,  0.01308896, -0.01829513,
         -0.01323594,  0.01266809, -0.01157204, -0.01712865,  0.00102041,
         -0.00915188,  0.00994506,  0.0003348 ,  0.01999653, -0.01448206,
          0.02454722, -0.00036955,  0.01865685, -0.00514613,  0.01885039,
         -0.02837751,  0.00375455, -0.04659929,  0.00275468,  0.00854283,
          0.02362405, -0.02947872,  0.01506701, -0.00749737, -0.02634757,
          0.02495424, -0.02921787,  0.01843503, -0.00016499,  0.01456888],
        dtype=float32),
  array([ 0.00049266, -0.02066331, -0.0207266 , -0.01927541,  0.02330254,
          0.00480025, -0.0293951 ,  0.01040582,  0.00090722, -0.0259526 ,
         -0.02241635,  0.        , -0.02288882,  0.02825862,  0.00242692,
          0.02985683,  0.00974142,  0.00933962, -0.02005748,  0.01471977,
          0.00262542,  0.00572324,  0.02363591,  0.00187776,  0.01032351,
          0.        , -0.01178178, -0.01856739,  0.        ,  0.01882949,
         -0.00219416,  0.01156303,  0.00352357,  0.01064696,  0.02187355,
          0.00582501, -0.02582753, -0.02192662, -0.00337636,  0.02204296,
          0.02771452,  0.0342248 ,  0.00142591,  0.0136618 ,  0.00158552,
          0.01320035,  0.0052134 , -0.03961967, -0.01420508,  0.00930639],
        dtype=float32),
  array([-0.01254899,  0.01254895], dtype=float32)])
In [13]:
# Extract the trained model's parameters: one weight matrix and one bias vector per trainable layer
W, B = get_parameters(model)
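To see what was just extracted, here is a quick sketch that prints each layer's shapes (assuming, consistent with the unpacking above, that get_parameters returns one weight matrix and one bias vector per trainable layer):

In [ ]:
# Sketch: inspect per-layer parameter shapes
for layer_idx, (w, b) in enumerate(zip(W, B)):
    print(f'layer {layer_idx}: weights {w.shape}, biases {b.shape}')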
In [14]:
# BASELINE SCENARIO
# Build the baseline model for the shards (a sequential Keras network),
# trained here on a single peer's shard.
# Shard size: number of training samples per peer
ss = int(len(X_train)/n_peers)
inputs_in = X_train[0*ss:0*ss+ss]
outputs_in = y_train[0*ss:0*ss+ss]
def build_model(X_t, y_t):
    model = Sequential()
    model.add(Dense(70, input_dim=Features_number, activation='relu'))
    model.add(Dense(50, activation='relu'))
    model.add(Dense(50, activation='relu'))
    model.add(Dense(2, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.fit(X_t,
              y_t,
              batch_size=32,
              epochs=250,
              verbose=1,
              validation_data=(X_test, y_test))
    return model
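A minimal usage sketch for the baseline scenario: train the network on the first peer's shard only (the name baseline_model is illustrative):

In [ ]:
# Sketch: fit the baseline model on the first shard
baseline_model = build_model(inputs_in, outputs_in)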
In [15]:
display(model.summary())
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 70)                910       
_________________________________________________________________
dense_2 (Dense)              (None, 50)                3550      
_________________________________________________________________
dense_3 (Dense)              (None, 50)                2550      
_________________________________________________________________
dense_4 (Dense)              (None, 2)                 102       
=================================================================
Total params: 7,112
Trainable params: 7,112
Non-trainable params: 0
_________________________________________________________________
None
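As a sanity check, each Dense layer holds units × (inputs + 1) parameters: the 910 figure implies Features_number = 12 here, giving 70×13 = 910, 50×71 = 3,550, 50×51 = 2,550 and 2×51 = 102, which sum to the reported 7,112 trainable parameters.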
In [16]:
# predict probabilities for test set
yhat_probs = model.predict(X_test, verbose=0)
# predict crisp classes for test set (equivalent to np.argmax(yhat_probs, axis=1))
yhat_classes = model.predict_classes(X_test, verbose=0)
In [17]:
# Ground-truth and predicted class labels (argmax over the softmax outputs)
y_true = np.argmax(y_test, axis=1)
y_pred = np.argmax(model.predict(X_test), axis=1)
# accuracy: (tp + tn) / (p + n)
accuracy = accuracy_score(y_true, y_pred)
print('Accuracy: %f' % accuracy)
# precision: tp / (tp + fp)
precision = precision_score(y_true, y_pred)
print('Precision: %f' % precision)
# recall: tp / (tp + fn)
recall = recall_score(y_true, y_pred)
print('Recall: %f' % recall)
# f1: 2*tp / (2*tp + fp + fn)
f1 = f1_score(y_true, y_pred)
print('F1 score: %f' % f1)
Accuracy: 0.836071
Precision: 0.791075
Recall: 0.483871
F1 score: 0.600462
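These figures can be checked by hand against the confusion matrix in the next cell: with tn = 2257, fp = 103, fn = 416 and tp = 390, accuracy = (2257 + 390)/3166 ≈ 0.8361, precision = 390/(390 + 103) ≈ 0.7911, recall = 390/(390 + 416) ≈ 0.4839, and F1 = 2 × 0.7911 × 0.4839 / (0.7911 + 0.4839) ≈ 0.6005, matching the printed values.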
In [18]:
# confusion matrix (rows: true class, columns: predicted class)
mat = confusion_matrix(y_true, y_pred)

display(mat)
plt.matshow(mat)
plt.colorbar()
plt.show()
array([[2257,  103],
       [ 416,  390]], dtype=int64)
[Output: matshow heatmap of the confusion matrix above]
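For a more readable figure, seaborn (imported above) can annotate each cell with its count. A minimal sketch, assuming the class order ['<=50K', '>50K'] produced by the one-hot encoding of income:

In [ ]:
# Sketch: annotated confusion-matrix heatmap (class labels are assumptions)
sns.heatmap(mat, annot=True, fmt='d', cmap='Blues',
            xticklabels=['<=50K', '>50K'], yticklabels=['<=50K', '>50K'])
plt.xlabel('Predicted class')
plt.ylabel('True class')
plt.show()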
In [19]:
# Dictionary holding the feature-importance vector computed by the scanner peer at each global round
FI_dic1 = {t: [] for t in range(10)}
In [ ]:
# Select one peer uniformly at random to act as the scanner peer
peers_selected = random.sample(range(n_peers), 1)
scaner = peers_selected[0]

# Percentage and number of peers participating at each global training epoch
percentage_participants = 1.0
n_participants = int(n_peers * percentage_participants)

# Number of global training epochs (rounds)
n_rounds = 10
# Attack window (inactive in this run, since no malicious peers are configured below)
start_attack_round = 4
end_attack_round = 7
# Number of local training epochs per global training epoch
n_local_rounds = 5

# Local batch size
local_batch_size = 32

# Local learning rate
local_lr = 0.001

# Global learning rate or 'gain'
model_substitution_rate = 1.0

# Attack detection / prevention mechanism = {None, 'distance', 'median', 'accuracy', 'krum'}
discard_outliers = None

# Used in 'distance' attack detection, defines how far the outliers are (1.5 is a typical value)
tau = 1.5

# Used in 'accuracy' attack detection, defines the error margin for the accuracy improvement
sensitivity = 0.05

# Used in 'krum' attack detection, defines how many byzantine attackers we want to defend against
tolerance = 4

# Prevent suspicious peers from participating again, only valid for 'distance' and 'accuracy'
ban_malicious = False

# Clear nans and infinites in model updates
clear_nans = True

number_for_threshold1 = numpy.zeros(20, dtype=float)
number_for_threshold2 = numpy.zeros(20, dtype=float)

########################
# ATTACK CONFIGURATION #
########################

# Percentage of malicious peers
r_malicious_peers = 0.0

# Number of malicious peers (absolute or relative to total number of peers)
n_malicious_peers = int(n_peers * r_malicious_peers)
#n_malicious_peers = 1

# Malicious peers
malicious_peer = range(n_malicious_peers)

# Target for coalitions
common_attack_target = [4,7]

# Target class of the attack, per each malicious peer
malicious_targets = dict([(p, t) for p,t in zip(malicious_peer, [common_attack_target]*n_malicious_peers)])

# Boosting parameter per each malicious peer
common_malicious_boost = 12
malicious_boost = dict([(p, b) for p,b in zip(malicious_peer, [common_malicious_boost]*n_malicious_peers)])

###########
# METRICS #
###########
metrics = {'accuracy': [],
          'atk_effectivity': [],
          'update_distances': [],
          'outliers_detected': [],

          'acc_no_target': []}

####################################
# MODEL AND NETWORK INITIALIZATION #
####################################
inputs = X_train[0*ss:0*ss+ss]
outputs = y_train[0*ss:0*ss+ss]
global_model = build_model(inputs,outputs)
n_layers = len(trainable_layers(global_model))

print('Initializing network.')
sleep(1)
network = []
for i in tqdm(range(n_peers)):
    ss = int(len(X_train)/n_peers)
    inputs = X_train[i*ss:i*ss+ss]
    outputs = y_train[i*ss:i*ss+ss]
    # All peers start from (and share) the same global model object;
    # its parameters are overwritten with the global ones at every round
    network.append(global_model)


banned_peers = set()

##################
# BEGIN TRAINING #
##################
for t in range(n_rounds):
    print(f'Round {t+1}.')
    sleep(1)

    ## SERVER SIDE #################################################################
    # Fetch global model parameters
    global_weights, global_biases = get_parameters(global_model)

    if clear_nans:
        global_weights, global_biases = nans_to_zero(global_weights, global_biases)

    # Initialize peer update lists
    network_weight_updates = []
    network_bias_updates = []

    # Selection of participant peers in this global training epoch
    if ban_malicious:
        good_peers = list([p for i,p in enumerate(network) if i not in banned_peers])
        n_participants = n_participants if n_participants <= len(good_peers) else int(len(good_peers) * percentage_participants)
        participants = random.sample(list(enumerate(good_peers)), n_participants)
    else:
        participants = random.sample(list(enumerate(network)),n_participants)
    ################################################################################


    ## CLIENT SIDE #################################################################
    for i, local_model in tqdm(participants):

        # Update local model with global parameters 
        set_parameters(local_model, global_weights, global_biases)

        # Initialization of user data
        ss = int(len(X_train)/n_peers)
        inputs = X_train[i*ss:i*ss+ss]
        outputs = y_train[i*ss:i*ss+ss]

        # Scanner peer: hold out part of its shard, fit the random forest once,
        # then use scan_wrong to score feature importances from the samples the global model mispredicts
        if(i == scaner):
            X_train_local, X_test_local, y_train_local, y_test_local = train_test_split(inputs,outputs, test_size=0.7, random_state=rs)
            inputs = X_train_local
            outputs = y_train_local
            if(t == 0):
                forest = build_forest(X_train_local,y_train_local)
            forest_predictions = forest.predict(X_test_local)
            acc_forest = np.mean([yt == yp for yt, yp in zip(y_test_local, forest_predictions)])
            FL_predict1 = global_model.predict(X_test_local)
            imp = scan_wrong(forest_predictions, FL_predict1, forest , y_test_local, X_test_local)
            FI_dic1[t] = imp


        # Benign peer: train the local model and collect its update
        local_weight_updates, local_bias_updates = get_updates(local_model, 
                                                                       inputs, outputs, 
                                                                       local_batch_size, n_local_rounds)
        if clear_nans:
            local_weight_updates, local_bias_updates = nans_to_zero(local_weight_updates, local_bias_updates)
        network_weight_updates.append(local_weight_updates)
        network_bias_updates.append(local_bias_updates)

    ## END OF CLIENT SIDE ##########################################################

    ######################################
    # SERVER SIDE AGGREGATION MECHANISMS #
    ######################################


    # Aggregate client updates (federated averaging)
    aggregated_weights, aggregated_biases = aggregate(n_layers, 
                                                      n_participants, 
                                                      np.mean, 
                                                      network_weight_updates, 
                                                      network_bias_updates)

    if clear_nans:
        aggregated_weights, aggregated_biases = nans_to_zero(aggregated_weights, aggregated_biases)

    # Apply updates to global model
    apply_updates(global_model, model_substitution_rate, aggregated_weights, aggregated_biases)


    ###################
    # COMPUTE METRICS #
    ###################

    # Global model accuracy
    score = global_model.evaluate(X_test, y_test, verbose=0)
    print(f'Global model loss: {score[0]}; global model accuracy: {score[1]}')
    metrics['accuracy'].append(score[1])


    # Accuracy excluding the attack target (identical to the above here, since no attack is configured)
    score = global_model.evaluate(X_test, y_test, verbose=0)
    metrics['acc_no_target'].append(score[1])


    # Distance of individual updates to the final aggregation
    metrics['update_distances'].append([dist_weights(aggregated_weights, w_i) for w_i in network_weight_updates])
Train on 420 samples, validate on 3166 samples
Epoch 1/250
420/420 [==============================] - 0s 475us/step - loss: 0.6159 - accuracy: 0.6571 - val_loss: 0.5505 - val_accuracy: 0.7454
Epoch 2/250
420/420 [==============================] - 0s 221us/step - loss: 0.5195 - accuracy: 0.7667 - val_loss: 0.5165 - val_accuracy: 0.7454
... (per-epoch training log truncated; validation accuracy plateaus around 0.82–0.83 by roughly epoch 35) ...
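The aggregation step above is plain federated averaging: for every trainable layer, the server takes the element-wise mean of the participants' weight and bias updates. A minimal self-contained sketch of what aggregate(n_layers, n_participants, np.mean, ...) is assumed to compute (the helper name fedavg is illustrative, not the notebook's API):

In [ ]:
# Sketch: element-wise FedAvg over per-peer, per-layer update lists
# weight_updates[p][l] is peer p's update for trainable layer l
def fedavg(weight_updates, bias_updates):
    n_layers = len(weight_updates[0])
    agg_w = [np.mean([peer[l] for peer in weight_updates], axis=0)
             for l in range(n_layers)]
    agg_b = [np.mean([peer[l] for peer in bias_updates], axis=0)
             for l in range(n_layers)]
    return agg_w, agg_b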
In [ ]:
# Sort the features by their importance in the last global round and print them

sort_index = np.argsort(FI_dic1[9])
for x in sort_index:
    print(names[x], ', ', FI_dic1[9][x])
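Beyond the final ranking, the per-round importances accumulated in FI_dic1 can show how the scanner peer's explanation evolves during training. A sketch, assuming each FI_dic1[t] is a per-feature importance vector and names holds the feature labels:

In [ ]:
# Sketch: importance trajectory of the top-5 features across global rounds
rounds = sorted(FI_dic1.keys())
for x in sort_index[-5:]:  # argsort is ascending, so the last five are the largest
    plt.plot(rounds, [FI_dic1[t][x] for t in rounds], label=names[x])
plt.xlabel('Global round')
plt.ylabel('Feature importance')
plt.legend()
plt.show()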