120 KiB
120 KiB
In [1]:
#IMPORTS import numpy as np import random import tensorflow as tf import tensorflow.keras as kr import tensorflow.keras.backend as K from tensorflow.keras.models import Model from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense from tensorflow.keras.datasets import mnist import os import csv from scipy.spatial.distance import euclidean from sklearn.metrics import confusion_matrix from time import sleep from tqdm import tqdm import copy import numpy from sklearn.datasets import make_classification from sklearn.ensemble import RandomForestClassifier import pandas as pd import matplotlib.pyplot as plt import math import seaborn as sns from numpy.random import RandomState import scipy as scp from sklearn.model_selection import train_test_split from sklearn.compose import ColumnTransformer from sklearn.preprocessing import OneHotEncoder, LabelEncoder from keras.models import Sequential from keras.layers import Dense from keras import optimizers from keras.callbacks import EarlyStopping,ModelCheckpoint from keras.utils import to_categorical from keras import backend as K from itertools import product from sklearn.metrics import accuracy_score from sklearn.metrics import precision_score from sklearn.metrics import recall_score from sklearn.metrics import f1_score from sklearn.metrics import roc_auc_score from sklearn.metrics import confusion_matrix from sklearn import mixture from mpl_toolkits.mplot3d import Axes3D import matplotlib.pyplot as plt %matplotlib inline
Using TensorFlow backend.
In [2]:
# Enter here the data set you want to explain (adult, activity, or synthatic) data_set = 'adult' # Enter here the numb er of peers you want in the experiments n_peers = 100
In [3]:
# the random state we will use in the experiments. It can be changed rs = RandomState(92)
In [4]:
# preprocessing adults data set if data_set == 'adult': #Load dataset into a pandas DataFrame adult_data = pd.read_csv('adult_data.csv', na_values='?') # Drop all records with missing values adult_data.dropna(inplace=True) adult_data.reset_index(drop=True, inplace=True) # Drop fnlwgt, not interesting for ML adult_data.drop('fnlwgt', axis=1, inplace=True) adult_data.drop('education', axis=1, inplace=True) # merging some similar features. adult_data['marital-status'].replace('Married-civ-spouse', 'Married', inplace=True) adult_data['marital-status'].replace('Divorced', 'Unmarried', inplace=True) adult_data['marital-status'].replace('Never-married', 'Unmarried', inplace=True) adult_data['marital-status'].replace('Separated', 'Unmarried', inplace=True) adult_data['marital-status'].replace('Widowed', 'Unmarried', inplace=True) adult_data['marital-status'].replace('Married-spouse-absent', 'Married', inplace=True) adult_data['marital-status'].replace('Married-AF-spouse', 'Married', inplace=True) adult_data = pd.concat([adult_data,pd.get_dummies(adult_data['income'], prefix='income')],axis=1) adult_data.drop('income', axis=1, inplace=True) obj_columns = adult_data.select_dtypes(['object']).columns adult_data[obj_columns] = adult_data[obj_columns].astype('category') # Convert numerics to floats and normalize num_columns = adult_data.select_dtypes(['int64']).columns adult_data[num_columns] = adult_data[num_columns].astype('float64') for c in num_columns: #adult[c] -= adult[c].mean() #adult[c] /= adult[c].std() adult_data[c] = (adult_data[c] - adult_data[c].min()) / (adult_data[c].max()-adult_data[c].min()) # 'workclass', 'marital-status', 'occupation', 'relationship' ,'race', 'gender', 'native-country' # adult_data['income'] = adult_data['income'].cat.codes adult_data['marital-status'] = adult_data['marital-status'].cat.codes adult_data['occupation'] = adult_data['occupation'].cat.codes adult_data['relationship'] = adult_data['relationship'].cat.codes adult_data['race'] = adult_data['race'].cat.codes adult_data['gender'] = adult_data['gender'].cat.codes adult_data['native-country'] = adult_data['native-country'].cat.codes adult_data['workclass'] = adult_data['workclass'].cat.codes num_columns = adult_data.select_dtypes(['int8']).columns adult_data[num_columns] = adult_data[num_columns].astype('float64') for c in num_columns: #adult[c] -= adult[c].mean() #adult[c] /= adult[c].std() adult_data[c] = (adult_data[c] - adult_data[c].min()) / (adult_data[c].max()-adult_data[c].min()) display(adult_data.info()) display(adult_data.head(10)) adult_data = adult_data.to_numpy() # splite the data to train and test datasets X_train, X_test, y_train, y_test = train_test_split(adult_data[:,:-2],adult_data[:,-2:], test_size=0.07, random_state=rs) # the names of the features names = ['age','workclass','educational-num','marital-status','occupation', 'relationship','race','gender','capital-gain','capital-loss','hours-per-week','native-country'] Features_number = len(X_train[0])
<class 'pandas.core.frame.DataFrame'> RangeIndex: 45222 entries, 0 to 45221 Data columns (total 14 columns): age 45222 non-null float64 workclass 45222 non-null float64 educational-num 45222 non-null float64 marital-status 45222 non-null float64 occupation 45222 non-null float64 relationship 45222 non-null float64 race 45222 non-null float64 gender 45222 non-null float64 capital-gain 45222 non-null float64 capital-loss 45222 non-null float64 hours-per-week 45222 non-null float64 native-country 45222 non-null float64 income_<=50K 45222 non-null uint8 income_>50K 45222 non-null uint8 dtypes: float64(12), uint8(2) memory usage: 4.2 MB
None
age | workclass | educational-num | marital-status | occupation | relationship | race | gender | capital-gain | capital-loss | hours-per-week | native-country | income_<=50K | income_>50K | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.109589 | 0.333333 | 0.400000 | 1.0 | 0.461538 | 0.6 | 0.5 | 1.0 | 0.000000 | 0.0 | 0.397959 | 0.95 | 1 | 0 |
1 | 0.287671 | 0.333333 | 0.533333 | 0.0 | 0.307692 | 0.0 | 1.0 | 1.0 | 0.000000 | 0.0 | 0.500000 | 0.95 | 1 | 0 |
2 | 0.150685 | 0.166667 | 0.733333 | 0.0 | 0.769231 | 0.0 | 1.0 | 1.0 | 0.000000 | 0.0 | 0.397959 | 0.95 | 0 | 1 |
3 | 0.369863 | 0.333333 | 0.600000 | 0.0 | 0.461538 | 0.0 | 0.5 | 1.0 | 0.076881 | 0.0 | 0.397959 | 0.95 | 0 | 1 |
4 | 0.232877 | 0.333333 | 0.333333 | 1.0 | 0.538462 | 0.2 | 1.0 | 1.0 | 0.000000 | 0.0 | 0.295918 | 0.95 | 1 | 0 |
5 | 0.630137 | 0.666667 | 0.933333 | 0.0 | 0.692308 | 0.0 | 1.0 | 1.0 | 0.031030 | 0.0 | 0.316327 | 0.95 | 0 | 1 |
6 | 0.095890 | 0.333333 | 0.600000 | 1.0 | 0.538462 | 0.8 | 1.0 | 0.0 | 0.000000 | 0.0 | 0.397959 | 0.95 | 1 | 0 |
7 | 0.520548 | 0.333333 | 0.200000 | 0.0 | 0.153846 | 0.0 | 1.0 | 1.0 | 0.000000 | 0.0 | 0.091837 | 0.95 | 1 | 0 |
8 | 0.657534 | 0.333333 | 0.533333 | 0.0 | 0.461538 | 0.0 | 1.0 | 1.0 | 0.064181 | 0.0 | 0.397959 | 0.95 | 0 | 1 |
9 | 0.260274 | 0.000000 | 0.800000 | 0.0 | 0.000000 | 0.0 | 1.0 | 1.0 | 0.000000 | 0.0 | 0.397959 | 0.95 | 1 | 0 |
In [5]:
if data_set == 'synthatic': #generate the data X, y = make_classification(n_samples=1000000, n_features=10, n_redundant=3, n_repeated=2, #n_classes=3, n_informative=5, n_clusters_per_class=4, random_state=42) y = pd.DataFrame(data=y, columns=["y"]) y = pd.get_dummies(y['y'], prefix='y') y = y.to_numpy() X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.07, random_state=rs) # the names of the features names = ['X(0)','X(1)','X(2)','X(3)','X(4)','X(5)','X(6)','X(7)','X(8)','X(9)'] Features_number = len(X_train[0])
In [6]:
if data_set == 'activity': #Load dataset into a pandas DataFrame activity = pd.read_csv("activity_3_original.csv", sep=',') # drop some features that have non value in the majority of the samples to_drop = ['subject', 'timestamp', 'heart_rate','activityID'] activity.drop(axis=1, columns=to_drop, inplace=True) # prepare the truth activity = pd.concat([activity,pd.get_dummies(activity['motion'], prefix='motion')],axis=1) activity.drop('motion', axis=1, inplace=True) class_label = [ 'motion_n', 'motion_y'] predictors = [a for a in activity.columns.values if a not in class_label] for p in predictors: activity[p].fillna(activity[p].mean(), inplace=True) display(predictors) for p in predictors: activity[p] = (activity[p]-activity[p].min()) / (activity[p].max() - activity[p].min()) activity[p].astype('float32') activity = activity.to_numpy() X_train, X_test, y_train, y_test = train_test_split(activity[:,:-2],activity[:,-2:], test_size=0.07, random_state=rs) # the names of the features names = ['temp_hand','acceleration_16_x_hand', 'acceleration_16_y_hand','acceleration_16_z_hand','acceleration_6_x_hand', 'acceleration_6_y_hand','acceleration_6_z_hand','gyroscope_x_hand','gyroscope_y_hand', 'gyroscope_z_hand','magnetometer_x_hand','magnetometer_y_hand','magnetometer_z_hand', 'temp_chest','acceleration_16_x_chest','acceleration_16_y_chest','acceleration_16_z_chest','acceleration_6_x_chest', 'acceleration_6_y_chest','acceleration_6_z_chest','gyroscope_x_chest','gyroscope_y_chest','gyroscope_z_chest', 'magnetometer_x_chest','magnetometer_y_chest','magnetometer_z_chest','temp_ankle','acceleration_16_x_ankle', 'acceleration_16_y_ankle','acceleration_16_z_ankle','acceleration_6_x_ankle','acceleration_6_y_ankle', 'acceleration_6_z_ankle','gyroscope_x_ankle','gyroscope_y_ankle','gyroscope_z_ankle','magnetometer_x_ankle', 'magnetometer_y_ankle','magnetometer_z_ankle'] Features_number = len(X_train[0])
In [7]:
#begin federated earlystopping = EarlyStopping(monitor = 'val_loss', min_delta = 0.01, patience = 50, verbose = 1, baseline = 2, restore_best_weights = True) checkpoint = ModelCheckpoint('test.h8', monitor='val_loss', mode='min', save_best_only=True, verbose=1) model = Sequential() model.add(Dense(70, input_dim=Features_number, activation='relu')) model.add(Dense(50, activation='relu')) model.add(Dense(50, activation='relu')) model.add(Dense(2, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy']) history = model.fit(X_train, y_train, epochs=2, validation_data=(X_test, y_test), callbacks = [checkpoint, earlystopping], shuffle=True)
Train on 42056 samples, validate on 3166 samples Epoch 1/2 42056/42056 [==============================] - ETA: 2:15 - loss: 0.7931 - accuracy: 0.28 - ETA: 5s - loss: 0.5651 - accuracy: 0.6988 - ETA: 3s - loss: 0.5165 - accuracy: 0.72 - ETA: 2s - loss: 0.4896 - accuracy: 0.74 - ETA: 2s - loss: 0.4646 - accuracy: 0.76 - ETA: 2s - loss: 0.4489 - accuracy: 0.77 - ETA: 1s - loss: 0.4412 - accuracy: 0.77 - ETA: 1s - loss: 0.4316 - accuracy: 0.78 - ETA: 1s - loss: 0.4258 - accuracy: 0.78 - ETA: 1s - loss: 0.4224 - accuracy: 0.79 - ETA: 1s - loss: 0.4158 - accuracy: 0.79 - ETA: 1s - loss: 0.4118 - accuracy: 0.79 - ETA: 1s - loss: 0.4073 - accuracy: 0.80 - ETA: 1s - loss: 0.4052 - accuracy: 0.80 - ETA: 1s - loss: 0.4025 - accuracy: 0.80 - ETA: 1s - loss: 0.3992 - accuracy: 0.80 - ETA: 1s - loss: 0.3983 - accuracy: 0.80 - ETA: 0s - loss: 0.3964 - accuracy: 0.80 - ETA: 0s - loss: 0.3930 - accuracy: 0.80 - ETA: 0s - loss: 0.3893 - accuracy: 0.81 - ETA: 0s - loss: 0.3881 - accuracy: 0.81 - ETA: 0s - loss: 0.3873 - accuracy: 0.81 - ETA: 0s - loss: 0.3857 - accuracy: 0.81 - ETA: 0s - loss: 0.3831 - accuracy: 0.81 - ETA: 0s - loss: 0.3811 - accuracy: 0.81 - ETA: 0s - loss: 0.3793 - accuracy: 0.81 - ETA: 0s - loss: 0.3784 - accuracy: 0.81 - ETA: 0s - loss: 0.3764 - accuracy: 0.81 - ETA: 0s - loss: 0.3756 - accuracy: 0.81 - ETA: 0s - loss: 0.3740 - accuracy: 0.82 - ETA: 0s - loss: 0.3726 - accuracy: 0.82 - 2s 41us/step - loss: 0.3720 - accuracy: 0.8217 - val_loss: 0.3452 - val_accuracy: 0.8370 Epoch 00001: val_loss improved from inf to 0.34516, saving model to test.h8 Epoch 2/2 42056/42056 [==============================] - ETA: 3s - loss: 0.5095 - accuracy: 0.84 - ETA: 1s - loss: 0.3444 - accuracy: 0.83 - ETA: 1s - loss: 0.3407 - accuracy: 0.84 - ETA: 1s - loss: 0.3383 - accuracy: 0.83 - ETA: 1s - loss: 0.3365 - accuracy: 0.84 - ETA: 1s - loss: 0.3374 - accuracy: 0.84 - ETA: 1s - loss: 0.3386 - accuracy: 0.84 - ETA: 1s - loss: 0.3357 - accuracy: 0.84 - ETA: 1s - loss: 0.3364 - accuracy: 0.84 - ETA: 1s - loss: 0.3353 - accuracy: 0.84 - ETA: 1s - loss: 0.3351 - accuracy: 0.84 - ETA: 1s - loss: 0.3368 - accuracy: 0.83 - ETA: 1s - loss: 0.3381 - accuracy: 0.83 - ETA: 1s - loss: 0.3396 - accuracy: 0.83 - ETA: 0s - loss: 0.3388 - accuracy: 0.83 - ETA: 0s - loss: 0.3384 - accuracy: 0.83 - ETA: 0s - loss: 0.3393 - accuracy: 0.83 - ETA: 0s - loss: 0.3390 - accuracy: 0.83 - ETA: 0s - loss: 0.3392 - accuracy: 0.83 - ETA: 0s - loss: 0.3400 - accuracy: 0.83 - ETA: 0s - loss: 0.3400 - accuracy: 0.83 - ETA: 0s - loss: 0.3410 - accuracy: 0.83 - ETA: 0s - loss: 0.3406 - accuracy: 0.83 - ETA: 0s - loss: 0.3415 - accuracy: 0.83 - ETA: 0s - loss: 0.3419 - accuracy: 0.83 - ETA: 0s - loss: 0.3421 - accuracy: 0.83 - ETA: 0s - loss: 0.3417 - accuracy: 0.83 - ETA: 0s - loss: 0.3410 - accuracy: 0.83 - ETA: 0s - loss: 0.3415 - accuracy: 0.83 - ETA: 0s - loss: 0.3419 - accuracy: 0.83 - ETA: 0s - loss: 0.3420 - accuracy: 0.83 - ETA: 0s - loss: 0.3411 - accuracy: 0.83 - ETA: 0s - loss: 0.3413 - accuracy: 0.83 - 2s 41us/step - loss: 0.3412 - accuracy: 0.8375 - val_loss: 0.3417 - val_accuracy: 0.8370 Epoch 00002: val_loss improved from 0.34516 to 0.34170, saving model to test.h8
In [8]:
#AUXILIARY METHODS FOR FEDERATED LEARNING # RETURN INDICES TO LAYERS WITH WEIGHTS AND BIASES def trainable_layers(model): return [i for i, layer in enumerate(model.layers) if len(layer.get_weights()) > 0] # RETURN WEIGHTS AND BIASES OF A MODEL def get_parameters(model): weights = [] biases = [] index = trainable_layers(model) for i in index: weights.append(copy.deepcopy(model.layers[i].get_weights()[0])) biases.append(copy.deepcopy(model.layers[i].get_weights()[1])) return weights, biases # SET WEIGHTS AND BIASES OF A MODEL def set_parameters(model, weights, biases): index = trainable_layers(model) for i, j in enumerate(index): model.layers[j].set_weights([weights[i], biases[i]]) # DEPRECATED: RETURN THE GRADIENTS OF THE MODEL AFTER AN UPDATE def get_gradients(model, inputs, outputs): """ Gets gradient of model for given inputs and outputs for all weights""" grads = model.optimizer.get_gradients(model.total_loss, model.trainable_weights) symb_inputs = (model._feed_inputs + model._feed_targets + model._feed_sample_weights) f = K.function(symb_inputs, grads) x, y, sample_weight = model._standardize_user_data(inputs, outputs) output_grad = f(x + y + sample_weight) w_grad = [w for i,w in enumerate(output_grad) if i%2==0] b_grad = [w for i,w in enumerate(output_grad) if i%2==1] return w_grad, b_grad # RETURN THE DIFFERENCE OF MODELS' WEIGHTS AND BIASES AFTER AN UPDATE # NOTE: LEARNING RATE IS APPLIED, SO THE UPDATE IS DIFFERENT FROM THE # GRADIENTS. IN CASE VANILLA SGD IS USED, THE GRADIENTS ARE OBTAINED # AS (UPDATES / LEARNING_RATE) def get_updates(model, inputs, outputs, batch_size, epochs): w, b = get_parameters(model) #model.train_on_batch(inputs, outputs) model.fit(inputs, outputs, batch_size=batch_size, epochs=epochs, verbose=0) w_new, b_new = get_parameters(model) weight_updates = [old - new for old,new in zip(w, w_new)] bias_updates = [old - new for old,new in zip(b, b_new)] return weight_updates, bias_updates # UPDATE THE MODEL'S WEIGHTS AND PARAMETERS WITH AN UPDATE def apply_updates(model, eta, w_new, b_new): w, b = get_parameters(model) new_weights = [theta - eta*delta for theta,delta in zip(w, w_new)] new_biases = [theta - eta*delta for theta,delta in zip(b, b_new)] set_parameters(model, new_weights, new_biases) # FEDERATED AGGREGATION FUNCTION def aggregate(n_layers, n_peers, f, w_updates, b_updates): agg_w = [f([w_updates[j][i] for j in range(n_peers)], axis=0) for i in range(n_layers)] agg_b = [f([b_updates[j][i] for j in range(n_peers)], axis=0) for i in range(n_layers)] return agg_w, agg_b # SOLVE NANS def nans_to_zero(W, B): W0 = [np.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0) for w in W] B0 = [np.nan_to_num(b, nan=0.0, posinf=0.0, neginf=0.0) for b in B] return W0, B0 def build_forest(X,y): clf=RandomForestClassifier(n_estimators=1000, max_depth=7, random_state=0, verbose = 1) clf.fit(X,y) return clf # COMPUTE EUCLIDEAN DISTANCE OF WEIGHTS def dist_weights(w_a, w_b): wf_a = flatten_weights(w_a) wf_b = flatten_weights(w_b) return euclidean(wf_a, wf_b) # TRANSFORM ALL WEIGHT TENSORS TO 1D ARRAY def flatten_weights(w_in): h = w_in[0].reshape(-1) for w in w_in[1:]: h = np.append(h, w.reshape(-1)) return h
In [9]:
# scan the forest for trees maches the wrong predictions of the black-box def scan_wrong(forest_predictions, FL_predict1, forest , y_test_local, X_test_local): sum_feature_improtance= 0 overal_wrong_feature_importance = 0 counter = 0 second_counter = 0 never_seen = 0 avr_wrong_importance = 0 FL_predict1 = np.argmax(FL_predict1, axis=1) forest_predictions = np.argmax(forest_predictions, axis=1) y_test_local = np.argmax(y_test_local, axis=1) for i in range (len(FL_predict1)): i_tree = 0 # if the black-box got a wrong prediction if (FL_predict1[i] != y_test_local[i]): # getting the prediction of the trees one by one for tree_in_forest in forest.estimators_: sample = X_test_local[i].reshape(1, -1) temp = forest.estimators_[i_tree].predict(sample) temp = np.argmax(temp, axis=1) i_tree = i_tree + 1 # if the prediction of the tree maches the predictions of the black-box if(FL_predict1[i] == temp): # getting the features importances sum_feature_improtance = sum_feature_improtance + tree_in_forest.feature_importances_ counter = counter + 1 # if we have trees maches the black-box predictions if(counter>0): ave_feature_importence = sum_feature_improtance/counter overal_wrong_feature_importance = ave_feature_importence + overal_wrong_feature_importance second_counter = second_counter + 1 counter = 0 sum_feature_improtance = 0 # if there is no trees maches the black-box predictions else: if(FL_predict1[i] != y_test_local[i]): never_seen = never_seen +1 # getting the average features importances for all the samples that had wrong predictions. if(second_counter>0): avr_wrong_importance = overal_wrong_feature_importance / second_counter return forest.feature_importances_
In [10]:
trainable_layers(model)
Out[10]:
[0, 1, 2, 3]
In [11]:
get_parameters(model)
Out[11]:
([array([[ 1.39432400e-01, 8.84631574e-02, -4.47415888e-01, 1.23670131e-01, -2.65049934e-01, 2.56673127e-01, 2.82177985e-01, -3.88451487e-01, -8.48813355e-02, -4.55360711e-01, -2.55180508e-01, -1.34169891e-01, -4.19932574e-01, 9.50885192e-02, -4.10533138e-02, 1.23161055e-01, -3.34913731e-01, -3.29331495e-02, -2.09537312e-01, 2.89370805e-01, -2.42182449e-01, 9.41318497e-02, -8.54814351e-02, -2.53278345e-01, 7.38841221e-02, 9.76254940e-02, 9.64644551e-03, 4.62163612e-02, 1.47847623e-01, 3.28071006e-02, -2.16738522e-01, -5.52587435e-02, -1.01704948e-01, 2.31297538e-01, -3.01694840e-01, 2.23755836e-02, -2.37541839e-01, -8.33741352e-02, -3.33046556e-01, -3.82800475e-02, -2.60576427e-01, 1.35413051e-01, 5.84374070e-02, -1.67372033e-01, 7.50956163e-02, -2.44477212e-01, 4.34608996e-01, 1.95100769e-01, -1.71157598e-01, 2.94538945e-01, -2.78368771e-01, 3.23733628e-01, 7.93107301e-02, 2.28328109e-01, 6.06352724e-02, -7.03767091e-02, 1.33410409e-01, 1.21751621e-01, 1.97286800e-01, -9.21699479e-02, -3.15490931e-01, 2.30563477e-01, -3.28507647e-02, 8.77456143e-02, 5.48780151e-02, -4.60406430e-02, -1.89183086e-01, 3.93763036e-02, 2.96199113e-01, -2.79987492e-02], [ 3.03673856e-02, 1.97539851e-02, -1.50838614e-01, 1.14162855e-01, 1.80196881e-01, -1.22831225e-01, 7.67074972e-02, -1.13835640e-01, -1.38265222e-01, -6.62374571e-02, 1.81205988e-01, -2.81262010e-01, 1.72400191e-01, 2.07341984e-01, -1.34065270e-01, 6.87680393e-02, -6.93561733e-02, -2.21116617e-01, 1.04925461e-01, 1.02081522e-02, 1.51008025e-01, -2.92544812e-02, -1.05958931e-01, 1.61262244e-01, 1.58383980e-01, -1.24027103e-01, -1.80273309e-01, 2.02690706e-01, 1.30619720e-01, -1.44045368e-01, 5.87314926e-02, -6.84582517e-02, 5.60571887e-02, -1.27603471e-01, 2.20635161e-01, 1.71862170e-01, 1.77298188e-02, 1.31710157e-01, -2.06363559e-01, 1.41939849e-01, 4.67592143e-02, -2.25164890e-01, 2.84170844e-02, -1.87025517e-01, 2.21437346e-02, 2.89680868e-01, 2.44593516e-01, 5.39705567e-02, 1.68798208e-01, -9.17015448e-02, -9.46003050e-02, -8.50451589e-02, -9.65483636e-02, 2.15933964e-01, 3.86347598e-03, -2.29437221e-02, 8.44280720e-02, 1.96231887e-01, 3.78342345e-02, 1.12372516e-02, 7.45132491e-02, -1.45243943e-01, 1.38520822e-01, 1.27623096e-01, 9.93933976e-02, 7.73796961e-02, 1.07909396e-01, 5.35671674e-02, -2.25077912e-01, 1.48774251e-01], [-2.48966157e-01, -1.18819617e-01, 3.78526822e-02, -4.11971584e-02, 5.32225370e-02, 2.79902488e-01, 3.43969136e-01, -5.78653142e-02, 1.67140678e-01, -2.94734612e-02, 1.13698818e-01, -2.92426739e-02, -1.79812416e-01, 2.88941506e-02, -1.41450733e-01, -7.92392809e-03, -1.35528877e-01, -2.56182522e-01, -2.33598545e-01, -5.47329225e-02, 2.58110791e-01, -2.45282829e-01, 4.75647040e-02, 4.78960238e-02, -6.56322390e-02, -6.67297915e-02, 1.69852525e-01, -1.50414899e-01, 2.58721203e-01, 1.14194579e-01, 2.65164256e-01, 8.89386758e-02, 2.67333359e-01, -3.09747636e-01, -1.52420253e-02, 2.57288337e-01, 1.46575630e-01, 8.43582675e-02, 1.89198285e-01, -5.13301976e-02, -1.45431489e-01, 1.83323875e-01, 2.22104147e-01, -7.55850300e-02, 1.44288674e-01, -1.75847083e-01, -1.43846169e-01, 1.33877620e-01, 1.63822114e-01, -1.28378317e-01, -2.10838597e-02, -2.69852519e-01, 1.04066990e-01, 2.06833377e-01, -1.28662705e-01, 1.49911791e-01, -2.75938064e-01, -3.31552997e-02, 2.19017982e-01, 6.46202068e-04, 1.66913256e-01, -1.72089741e-01, 9.96593982e-02, -2.43812397e-01, -8.03031027e-02, -1.92508698e-01, -3.14832121e-01, -9.16534588e-02, -3.15453112e-01, 1.48415402e-01], [-2.35771656e-01, 3.27018127e-02, 1.60873935e-01, -1.28616795e-01, 3.11803758e-01, -2.35472228e-02, -1.39719948e-01, 1.74694061e-02, 7.51914829e-02, 2.35624880e-01, 7.33765140e-02, 2.16503426e-01, 4.06566672e-02, -2.05656707e-01, 1.96258724e-01, 5.99774197e-02, -1.27538797e-02, -6.30170330e-02, -1.16274104e-01, -1.43104732e-01, -1.37973130e-01, -1.91767380e-01, 3.22461128e-01, 2.99887396e-02, 2.64688015e-01, -2.45580390e-01, -2.41390377e-01, -1.29994661e-01, -1.80605844e-01, -2.61187732e-01, 1.44567609e-01, 1.88110307e-01, 1.73101038e-01, 2.86840070e-02, -1.33754045e-01, -5.33887371e-02, -1.13288000e-01, -8.15718770e-02, 2.53453523e-01, -1.54690027e-01, -1.32443011e-02, -6.94180205e-02, -1.20536266e-02, -2.19712891e-02, -2.30549023e-01, 2.46970072e-01, -1.82330459e-02, -1.24268174e-01, 2.66243219e-01, 1.11885495e-01, 8.33856687e-02, -1.06503241e-01, -2.80220248e-02, -1.17930442e-01, 2.08708122e-01, 7.04001710e-02, -1.37973502e-02, 1.89776018e-01, -7.30874389e-02, -2.11521506e-01, 1.42071024e-01, 2.42409576e-02, 8.69186819e-02, 3.34844626e-02, -2.07044452e-01, -1.04645088e-01, 1.51515082e-01, -1.95780490e-02, 2.13911623e-01, 9.59823653e-02], [-2.26251304e-01, -4.98282760e-02, 8.57945010e-02, 1.85095415e-01, 1.94030240e-01, 1.70300901e-01, -1.48310944e-01, -1.68697998e-01, 1.38381734e-01, -8.20567235e-02, 1.35808028e-02, -1.75055087e-01, 2.08388101e-02, -2.22936451e-01, -7.68952891e-02, -4.24526669e-02, 4.03720774e-02, 2.34893888e-01, -1.57926619e-01, -2.40865514e-01, 1.67401552e-01, 2.16235057e-01, -1.50564939e-01, 1.77459866e-01, -1.02011845e-01, 9.56041086e-03, -1.36439502e-01, 1.67499810e-01, 1.46594793e-01, -2.37665162e-03, 2.35330492e-01, -4.87338640e-02, -8.25209543e-02, -7.34776333e-02, 2.11637601e-01, -8.63815099e-02, -2.52601802e-01, -1.03249528e-01, 1.14807218e-01, 1.93410560e-01, -7.48374164e-02, 4.09806073e-02, -1.25015989e-01, 1.75860271e-01, 1.65006757e-01, 1.63865000e-01, 1.56919926e-01, -2.22888529e-01, -3.29164751e-02, 4.06037048e-02, 2.24684268e-01, 1.01046182e-01, -1.53632820e-01, -1.65310353e-01, 4.86176573e-02, -2.46649399e-01, -2.84075760e-03, 1.55264661e-01, 4.27330621e-02, -2.05510065e-01, 1.62713528e-01, -3.14808562e-02, 1.86110288e-01, 6.84845075e-02, 4.47224490e-02, -3.40451181e-01, 1.40326787e-02, 2.19547436e-01, 7.52496868e-02, 1.09770238e-01], [-2.14519277e-01, -1.97733060e-01, -1.04191333e-01, -1.52826672e-02, 1.04496861e-03, -6.56969398e-02, -7.04714730e-02, -1.19291015e-01, 1.01761602e-01, 7.52121955e-02, -2.15532720e-01, -1.47176266e-01, 1.51603609e-01, -1.83050726e-02, -3.25457342e-02, -5.11338934e-02, 1.16198196e-03, -2.66087204e-01, 7.53995031e-02, 7.98415840e-02, 4.19246480e-02, -7.96627849e-02, 1.22839414e-01, 1.80793643e-01, -2.73334742e-01, 5.54925241e-02, 1.19968027e-01, 1.63323641e-01, 1.11940101e-01, -1.46585837e-01, 1.94005132e-01, 1.88561931e-01, -5.62924668e-02, -4.18225750e-02, -1.56423241e-01, -2.25715101e-01, -4.82656956e-02, 2.14031748e-02, 2.10182130e-01, -3.18871409e-01, -7.38589093e-02, -2.32924759e-01, 8.74556080e-02, -1.10086516e-01, 1.84157446e-01, -1.46957889e-01, -1.06122330e-01, 2.88575172e-01, 7.43130967e-02, 1.63028061e-01, 2.40940854e-01, 8.84263813e-02, 1.86871052e-01, -1.03018314e-01, -2.51245052e-02, -2.32590944e-01, 2.58567259e-02, 1.24988005e-01, 4.27892543e-02, 6.42778203e-02, 2.41022035e-01, -5.46587259e-02, -1.77857980e-01, 3.70368622e-02, 2.42744144e-02, 1.84613451e-01, 2.30415717e-01, -1.80632919e-01, -9.84579027e-02, -4.87778150e-02], [-2.97077070e-03, -9.92525965e-02, 9.59780440e-02, -1.05714351e-01, -2.09908143e-01, 2.08500147e-01, -9.31153223e-02, 2.99151987e-01, 4.34016176e-02, -2.24611446e-01, 3.31769064e-02, 2.14490488e-01, -2.24754527e-01, -1.74998924e-01, -4.15243544e-02, -1.69698030e-01, 2.80564696e-01, 1.17882535e-01, -9.80678648e-02, 3.15327570e-03, -2.08990425e-01, 1.49431065e-01, -1.39306724e-01, 2.40346678e-02, 2.40564555e-01, -5.09837978e-02, 2.17804000e-01, 1.35088935e-01, 8.79955664e-02, -5.64928725e-02, 4.61013429e-02, 6.54249862e-02, -8.42749923e-02, 2.62729824e-01, -3.99206020e-02, -1.17483221e-01, -1.40452668e-01, -1.06828704e-01, -1.74000204e-01, -4.49550189e-02, 2.60878950e-01, 2.07423091e-01, -9.15924609e-02, 1.91001654e-01, -1.47255644e-01, 7.95471966e-02, -1.70050204e-01, 5.61165512e-02, -1.48466706e-01, 1.08682081e-01, 2.04737335e-02, -1.74528554e-01, -9.47896019e-02, 1.73530400e-01, -1.12356387e-01, 9.92965326e-02, 1.26004890e-01, -2.32813179e-01, 9.49711502e-02, -2.34253883e-01, -2.76989549e-01, -7.66268969e-02, -3.41671556e-02, -5.10511408e-03, -5.79159260e-02, -6.46380782e-02, -5.50055876e-02, -3.11404735e-01, 2.45275497e-01, -2.22187296e-01], [ 4.54801060e-02, 2.56455511e-01, -1.82633027e-01, -1.01602580e-02, -8.93032998e-02, 1.04237944e-01, 5.84088564e-02, 1.54823989e-01, -1.07336426e-02, 2.69688278e-01, 6.16033142e-03, -6.09616982e-03, 8.98296311e-02, 1.78536490e-01, -1.43777172e-03, -9.94328558e-02, -4.55807038e-02, 9.91010815e-02, -4.42102812e-02, 3.77892517e-02, 1.33471981e-01, 7.44501278e-02, 1.62690468e-02, 2.23104075e-01, -2.61054993e-01, 3.15811366e-01, -2.96082795e-01, 1.78025752e-01, -2.63285220e-01, -5.37474826e-02, -9.58651751e-02, -2.15012103e-01, -4.33603339e-02, -2.60652751e-01, -5.41594252e-02, 2.35952377e-01, -3.74763012e-02, 1.91953376e-01, 1.17158510e-01, 3.78518994e-03, -6.19563572e-02, 2.10780635e-01, 1.62149847e-01, -1.30085796e-01, 1.28252106e-03, 2.28483707e-01, -1.14689972e-02, -8.24389532e-02, -1.77851245e-01, -1.37649611e-01, 1.65123567e-01, 1.03654794e-01, 8.36220309e-02, 1.99557766e-02, 6.00132421e-02, -3.04210056e-02, -2.81973660e-01, -2.42123492e-02, -2.17434868e-01, -9.64278206e-02, -1.85030416e-01, -2.62960136e-01, 5.34782112e-02, 1.58508420e-01, 1.65380761e-01, -3.85079943e-02, 2.55265355e-01, 5.09922206e-02, -1.47566527e-01, 7.40251169e-02], [ 1.29649222e-01, -3.14282179e-02, -6.06167972e-01, -2.50955880e-01, -3.46874207e-01, 7.49993503e-01, 7.28010595e-01, -1.06399655e+00, 1.06234324e+00, -3.55233133e-01, -5.50140023e-01, 1.00409508e+00, -5.45210958e-01, 1.93181217e-01, -7.01776028e-01, -2.10634783e-01, -4.23527777e-01, 3.09440106e-01, -1.91907719e-01, 2.85458267e-01, 7.82932997e-01, -5.32808244e-01, -4.39185768e-01, -7.65542090e-01, -3.82927716e-01, -1.15567505e+00, 1.67764112e-01, -9.48192775e-01, 3.64812821e-01, 2.28667915e-01, 6.75961256e-01, 8.27623010e-01, -6.38736844e-01, -2.00036347e-01, -3.25849533e-01, 9.03906941e-01, -2.68816352e-01, -6.27302647e-01, -3.23336124e-01, 4.70992297e-01, -5.73931932e-01, 9.17997599e-01, 7.42488205e-01, -2.06164107e-01, 2.04111740e-01, -7.19973087e-01, -3.76782537e-01, 8.55549395e-01, -8.38361323e-01, -9.57333803e-01, -5.20633638e-01, -3.67659301e-01, 1.50605768e-01, -7.64182091e-01, -3.19448918e-01, -5.01123592e-02, 1.64251193e-01, -7.17021644e-01, 6.97100699e-01, -1.25111267e-01, -4.96421248e-01, -5.05610764e-01, 1.01232016e+00, -1.09313202e+00, 3.20109189e-01, -1.06782168e-01, -9.03548539e-01, -2.81452984e-01, -2.17785537e-01, 6.68265998e-01], [-2.42571607e-01, 2.04211175e-01, -4.92268875e-02, -1.63620815e-01, 4.04583551e-02, 3.45696330e-01, 3.54173370e-02, 8.10830146e-02, -1.61551312e-03, -1.48698622e-02, 1.94258001e-02, 1.15005746e-01, -1.20848659e-02, 2.12298751e-01, 8.92769620e-02, -7.64900148e-02, -3.41445431e-02, 7.51630887e-02, -3.40494029e-02, 2.70350277e-01, 4.42853682e-02, 5.13006859e-02, 2.81202555e-01, 1.35484681e-01, 7.24686086e-02, -1.34075984e-01, 1.70696169e-01, 8.00305977e-03, 2.56366223e-01, 1.33748680e-01, 2.25041300e-01, -7.13687986e-02, -4.96987440e-02, 3.10503058e-02, -2.25651234e-01, 3.94519985e-01, 2.15304196e-01, 1.15869548e-02, 1.47072956e-01, 3.18337977e-01, -6.86229253e-03, -5.09570874e-02, 2.29824454e-01, 2.61031240e-02, 2.89728433e-01, 3.48875783e-02, -5.50319031e-02, 1.21588549e-02, -4.12969440e-02, 1.07327215e-01, 1.35437414e-01, -2.93096341e-02, 3.36093381e-02, -1.90401971e-01, -2.66215026e-01, 6.08073771e-02, 6.91038072e-02, -1.98440487e-03, 7.31287152e-02, -2.77851731e-01, -1.08341835e-01, -1.85085818e-01, 2.29901448e-01, -2.96091676e-01, 2.23246500e-01, -1.44393981e-01, -1.93921745e-01, -1.92566663e-01, 1.32529914e-01, -1.94337085e-01], [-2.53917843e-01, -2.51892120e-01, -1.32432416e-01, 1.47464365e-01, -3.17318618e-01, 1.97301418e-01, 2.69987226e-01, -4.56497446e-02, 2.30195507e-01, 1.32218450e-02, -4.06064779e-01, 2.51328260e-01, 5.33021428e-02, -2.66608417e-01, -1.79995075e-01, 1.29986405e-01, 2.86205828e-01, 1.35580912e-01, -2.18271971e-01, 1.57579169e-01, 2.17058808e-01, -1.03528440e-01, -4.87327874e-02, -6.85375035e-02, -1.29330382e-01, -1.23507090e-01, 4.80753556e-02, -3.16315889e-01, 2.06642285e-01, -1.25930071e-01, 4.74674441e-03, -2.28398144e-01, 1.07306920e-01, 2.11515024e-01, -1.60666719e-01, 1.23706006e-01, -2.24141285e-01, 5.61789013e-02, 1.76867880e-02, -9.17073190e-02, 8.19897652e-02, -1.55695155e-02, 3.17650735e-01, 2.38761097e-01, 2.45510742e-01, 8.75920355e-02, 3.21321398e-01, -1.12799473e-01, 2.10606474e-02, -1.81161851e-01, 2.40284592e-01, -2.50088274e-01, -2.18976215e-01, 1.12234220e-01, 4.77548651e-02, -4.73017395e-02, 1.37630356e-02, 1.92280307e-01, 7.14965388e-02, -6.21563159e-02, -7.16416389e-02, 1.23388998e-01, 1.82368487e-01, 2.31735930e-01, -2.02105597e-01, 1.42061830e-01, 1.23616353e-01, 1.56008020e-01, -2.19544828e-01, 2.12301493e-01], [-7.18890205e-02, -1.92233965e-01, 2.33305559e-01, 6.87015578e-02, 8.51642191e-02, -2.19545767e-01, 6.00749105e-02, 3.61590572e-02, -6.68269545e-02, 1.48855716e-01, 2.30278343e-01, 2.16507941e-01, 2.22660348e-01, -2.84734219e-01, 2.37847969e-01, -1.55460656e-01, -2.26989180e-01, -6.12188876e-02, 1.77810416e-01, 1.45450696e-01, 2.52608925e-01, -1.36337921e-01, -1.94631949e-01, 1.51410148e-01, 3.44162211e-02, 1.61046118e-01, -1.24759860e-01, 1.83450043e-01, 1.75598450e-02, -2.05802217e-01, -8.66022483e-02, 6.08737469e-02, -2.22572535e-01, -1.20819479e-01, -1.68945014e-01, 2.10285246e-01, -2.40360171e-01, -1.79741889e-01, -1.93881094e-01, 1.32005673e-03, -8.93675536e-02, -1.65670961e-01, -8.00130144e-02, -2.01122567e-01, 1.55159965e-01, 4.84559573e-02, -1.92197278e-01, 1.46897465e-01, -1.71061575e-01, 5.79360016e-02, -1.45457163e-01, 1.78534076e-01, 1.95346713e-01, -9.44947526e-02, -2.78981924e-01, -1.16451114e-01, 1.21675292e-02, -1.05452980e-03, 2.97299847e-02, 1.15553983e-01, -1.47618756e-01, 2.83984572e-01, -9.44054872e-02, -6.82652295e-02, 1.54531911e-01, -9.11844522e-02, 2.69836523e-02, -3.09856743e-01, 6.67436346e-02, 2.40427703e-01]], dtype=float32), array([[ 0.04073317, -0.16170031, 0.08170982, ..., 0.12143514, -0.03804543, -0.1848121 ], [-0.02006259, 0.04184515, 0.20358184, ..., 0.08938669, 0.02554417, -0.0998741 ], [-0.05848309, -0.13393435, 0.28651938, ..., -0.19336581, 0.28697622, -0.18376462], ..., [-0.13138615, -0.10152157, 0.05253223, ..., 0.16827357, 0.09525165, 0.17411834], [-0.00976845, -0.10780089, 0.2228816 , ..., 0.1733975 , -0.10156322, 0.03318954], [ 0.09590832, -0.01828083, 0.12743485, ..., 0.25016934, 0.12800731, -0.10581163]], dtype=float32), array([[-0.00384881, -0.12021059, 0.01248708, ..., 0.01682259, -0.17754331, 0.02930963], [-0.03520177, 0.0117013 , 0.03343487, ..., -0.16231427, 0.1756002 , 0.00351096], [-0.1752005 , 0.004585 , -0.11959553, ..., -0.17236647, 0.28346488, 0.26809448], ..., [ 0.01488994, 0.00250473, -0.25695267, ..., -0.11059541, 0.17581026, -0.23348542], [ 0.21297403, 0.24602796, 0.06359419, ..., 0.205567 , 0.04510517, 0.11687386], [-0.17597616, 0.07059528, 0.10327347, ..., -0.02315794, 0.00959007, -0.01356981]], dtype=float32), array([[-2.82930046e-01, 1.26908660e-01], [ 2.37486243e-01, -3.81716669e-01], [ 9.92978290e-02, -3.47963899e-01], [ 3.02352726e-01, -3.74164760e-01], [-2.05417976e-01, 2.52470911e-01], [-4.55201864e-02, -2.02432677e-01], [ 1.73006430e-01, -4.46816646e-02], [-2.84130216e-01, -2.26977065e-01], [-4.35910234e-03, 3.76744062e-01], [ 1.45330116e-01, -3.25348943e-01], [ 2.28147835e-01, -2.77784109e-01], [-1.19501755e-01, 4.07545753e-02], [ 1.01264335e-01, -2.43342578e-01], [-1.60477936e-01, 5.24386704e-01], [ 6.06849305e-02, 9.89513546e-02], [-2.89398909e-01, 1.83537304e-01], [ 1.01001307e-01, 2.95499355e-01], [-2.97017217e-01, 3.22097719e-01], [ 1.97861195e-01, -2.02269956e-01], [-7.52512068e-02, 9.88621786e-02], [-1.38137221e-01, -4.40452248e-01], [-2.33402535e-01, 1.64692864e-01], [ 1.27101064e-01, 1.98759794e-01], [-3.01784992e-01, 2.12811917e-01], [-1.96352318e-01, 1.54295802e-01], [ 2.49975443e-01, -2.01082289e-01], [-1.38984874e-01, 2.29037121e-01], [ 1.06105595e-04, 2.89339125e-01], [-3.00384670e-01, -4.83968072e-02], [-1.86271910e-02, 3.05029899e-01], [ 1.99009106e-03, 2.14236692e-01], [-3.34532440e-01, -6.11541159e-02], [-1.35282487e-01, 5.65957166e-02], [-2.68078923e-01, 1.56603098e-01], [-2.48180240e-01, 2.94318020e-01], [-1.30787000e-01, -4.86690477e-02], [ 2.72127450e-01, 1.19140044e-01], [ 2.90248722e-01, -1.86683103e-01], [ 9.85520706e-02, -2.18973175e-01], [-3.67538421e-03, -1.40206725e-03], [ 2.09687546e-01, 2.38504097e-01], [-1.00464404e-01, 2.42502570e-01], [ 1.55400500e-01, 1.01416796e-01], [-3.02865952e-01, 7.42565766e-02], [-1.75600335e-01, 3.34860444e-01], [-1.38489887e-01, 2.21242890e-01], [ 1.80740595e-01, 2.85507560e-01], [ 2.81139612e-01, -1.64098963e-01], [ 1.56524777e-01, -3.87664348e-01], [ 4.04402643e-01, 3.33227254e-02]], dtype=float32)], [array([-0.02854579, 0.05311754, 0.04959053, -0.03715162, 0.02330466, -0.04269688, -0.01013005, 0.06874033, -0.01974296, 0.06377108, 0.01915232, -0.02997592, 0.04385599, -0.03190788, 0.0822746 , -0.0659938 , 0.05614723, -0.05916027, 0.00469705, -0.07926938, -0.04140961, 0.04683878, 0.05930374, 0.03659062, 0.01063251, 0.08476436, -0.05785139, 0.06547377, -0.06716973, -0.03636409, -0.07686226, -0.06849793, 0.05797694, 0.02435303, 0.03509652, -0.05251152, 0. , 0.08996248, 0.03742762, -0.01523749, 0.03550566, -0.03804714, -0.00484502, -0.00379997, -0.08697824, 0.05449862, 0.0753291 , -0.04323955, 0.06112291, 0.07690092, -0.0228012 , 0.0299318 , -0.00519692, -0.02068757, -0.00509608, -0.05934853, 0.01074862, 0.04968895, -0.06797037, -0.03621136, 0.04376265, 0.05112915, -0.01860299, 0.1174278 , -0.07478895, -0.03162573, 0.03712742, -0.02407979, 0.01463216, -0.06149809], dtype=float32), array([-0.00573702, 0.02881279, 0.08348639, 0.05104946, 0.00677452, 0.07581051, -0.00896564, 0.05683671, -0.04194446, -0.03258895, 0.02584886, -0.01876919, -0.03003295, 0.06011844, -0.03459567, -0.02480574, -0.02663354, -0.01426572, 0.04364663, 0.08130559, 0.06495807, -0.04167927, 0.05315804, 0.06356885, -0.0253933 , 0.04731547, -0.05747719, 0.05308496, -0.05659075, 0.05597835, -0.04872588, -0.04101903, 0.03206719, 0.05267171, -0.05795552, 0.02027026, -0.05280019, 0.05442906, -0.01780812, 0.06429685, -0.04762284, 0.06730605, -0.05093784, 0.05504497, 0.05357518, -0.02114536, 0.07897805, 0.02339305, 0.08138859, -0.03255979], dtype=float32), array([-0.02713361, 0.09309322, 0.03805388, 0.03424142, -0.01113803, 0.01667055, 0.07447569, -0.06668621, -0.03150655, 0.06387825, 0.07785708, -0.02531729, 0.0723519 , 0.00347189, -0.0177357 , 0. , -0.03487411, -0.04851071, 0.06759044, -0.03533144, 0.02401838, -0.03073125, -0.04450166, -0.02654641, -0.02827855, 0. , -0.01567897, -0.0040624 , -0.02975524, -0.03281569, -0.02176444, -0.03663797, -0.01836163, -0.01098274, -0.05059136, -0.02361586, 0.06188901, 0.07352342, 0.05501923, -0.04410602, -0.04001745, -0.05807052, 0.05206716, -0.01867674, -0.0089961 , -0.01759916, -0.01311922, -0.01491712, 0.05937983, 0.06336498], dtype=float32), array([ 0.05128358, -0.05128355], dtype=float32)])
In [12]:
get_updates(model, X_train, y_train, 32, 2)
Out[12]:
([array([[ 1.49011612e-06, -1.16565228e-02, 7.42741823e-02, 1.84727460e-02, -8.54203105e-03, -1.07147992e-02, -1.21345818e-02, 7.30311871e-02, -9.66983289e-03, 6.97897077e-02, 3.34345400e-02, 8.74996185e-05, 8.57096016e-02, 1.60385072e-02, 1.19651444e-02, 2.49370933e-03, 8.43369067e-02, 1.64110810e-02, 2.05034018e-02, -1.07030272e-02, 1.41942352e-02, 8.08978826e-03, 3.21836621e-02, 5.65186143e-02, -4.39171195e-02, 4.30863351e-03, -1.15206484e-02, 2.10180618e-02, -7.25415349e-03, 3.23727727e-06, -3.62303257e-02, 1.32649988e-02, 6.42685592e-03, 1.92415118e-02, 3.42182815e-02, 1.04672834e-02, 0.00000000e+00, 2.78882086e-02, 4.19589877e-02, -2.18840521e-02, 9.14855003e-02, 4.81577218e-03, -5.83276898e-03, 1.10103935e-02, -1.27351061e-02, 1.04498863e-02, -4.18264568e-02, 2.71596760e-02, 2.05704421e-02, -3.03956270e-02, 3.23033035e-02, -4.21546996e-02, 4.45771776e-02, -3.91504467e-02, -8.15168023e-02, 2.26370320e-02, -4.53699529e-02, -4.22071815e-02, -1.31393373e-02, 1.32917091e-02, 3.45049202e-02, -2.07801014e-02, -3.54883261e-03, 1.25047117e-02, -6.88902661e-03, -3.12848352e-02, 1.91906095e-02, 4.19499911e-02, -6.82501495e-02, -1.53118074e-02], [ 1.26473606e-06, 3.19551006e-02, -9.95661318e-03, 1.19727850e-02, -1.79602206e-03, 1.53238177e-02, 3.48728746e-02, -6.26268983e-03, 1.16313845e-02, -1.24692582e-02, 7.92261958e-03, 1.70102119e-02, -6.51185215e-03, 1.54842287e-02, -1.05841383e-02, 6.68349117e-03, -5.42246550e-03, 2.38585621e-02, -4.80242074e-03, 4.91370745e-02, 1.68562233e-02, 1.78126954e-02, 1.90734789e-02, 1.65916681e-02, 1.61990523e-03, -8.57291371e-03, 2.71531343e-02, -7.27659464e-03, 4.15879637e-02, 2.78651714e-06, -5.89160994e-03, -5.41770346e-02, 2.21652985e-02, 2.85971761e-02, 2.64423043e-02, 2.02812403e-02, 0.00000000e+00, 4.61198390e-03, 1.03888065e-02, -1.25160664e-02, 6.73976541e-03, 7.16526806e-03, 2.04948001e-02, 3.80229056e-02, 5.23998961e-03, 1.47194266e-02, -3.79353762e-04, 1.02847666e-02, 2.61461437e-02, -2.40194798e-02, 6.51396215e-02, -3.63333039e-02, -5.51320203e-02, 6.67399764e-02, 2.30744481e-04, -1.98918562e-02, -1.48451179e-02, 2.16108412e-02, 2.36949362e-02, 1.34854689e-02, 2.18952857e-02, -3.03417742e-02, 1.44592375e-02, -5.22305071e-03, 2.16116756e-02, -1.71370506e-02, -1.07840151e-02, 2.18297653e-02, -4.98308837e-02, 6.17059022e-02], [ 7.59959221e-07, 6.49914891e-03, 6.02702424e-03, 4.44378480e-02, -3.70854139e-03, -4.58493829e-03, -6.07848167e-03, 1.07879266e-02, -7.88053870e-03, 5.81801683e-03, -5.74702024e-03, -1.39639974e-02, -6.83215261e-03, 4.36457917e-02, 3.14892232e-02, 7.88080785e-03, -8.65940750e-03, 2.13787258e-02, -4.84359264e-03, 3.76113243e-02, -7.68777728e-03, 3.44041884e-02, 2.17271540e-02, 3.49786282e-02, -9.57245380e-03, 1.17718354e-02, 1.76328421e-03, 2.47428566e-02, -1.89929008e-02, 9.83476639e-07, -4.87717986e-03, 1.81090236e-02, -1.68472528e-03, 1.75802708e-02, 1.03006195e-02, 2.29674578e-03, 0.00000000e+00, 1.71143487e-02, -5.93899190e-03, -8.52180272e-03, 2.21078098e-03, -6.65974617e-03, -5.27688861e-03, 1.78252906e-03, 1.92168504e-02, 3.20638865e-02, 3.12381685e-02, -9.20593739e-04, 2.36202478e-02, 1.44775957e-02, -2.05530357e-02, 1.06063485e-03, 3.32839042e-03, -1.01231039e-03, -2.06442177e-03, 9.97217000e-03, 2.65306234e-03, 1.02972686e-02, -7.31389225e-03, 8.04215856e-03, 9.98052955e-03, -1.05124712e-03, -6.54879957e-03, 2.02297121e-02, 1.26888603e-02, 2.11532712e-02, 8.01965594e-03, -1.44157931e-02, -1.14605129e-02, 5.96144795e-03], [ 0.00000000e+00, -1.39323249e-02, 6.39429688e-03, -6.28607944e-02, -8.64857435e-03, 1.24468859e-02, 2.86395848e-03, 8.28726962e-03, 1.22120306e-02, 3.18306237e-02, -6.85551018e-03, 2.63845623e-02, 2.07465943e-02, 0.00000000e+00, 2.27561444e-02, 1.00681484e-02, 2.27013733e-02, -6.71585724e-02, -7.02583045e-03, -1.04030967e-02, -4.20713425e-03, 1.82433575e-02, -1.60830319e-02, 1.62253212e-02, -1.33093297e-02, -3.11521888e-02, 3.77973914e-03, 5.21981716e-03, -1.98365897e-02, 0.00000000e+00, 2.18344629e-02, -2.21273601e-02, 3.62999588e-02, 8.26938823e-03, -7.49810040e-03, -1.21858753e-02, 0.00000000e+00, -1.79555565e-02, 6.84709847e-03, -5.80060855e-02, 5.37337959e-02, -3.77775021e-02, -7.55639747e-04, -2.18640100e-02, -4.38468456e-02, 1.06569529e-02, -2.04435959e-02, 1.75144523e-02, -3.92878056e-03, -3.37126106e-03, 9.41401571e-02, 2.34664902e-02, 3.57846729e-02, -1.89120620e-02, -3.53061408e-02, -1.61360130e-02, -7.82630816e-02, 2.25571841e-02, -3.70850414e-03, 0.00000000e+00, 1.72483921e-03, -3.43789682e-02, -6.79399073e-03, 8.89310054e-03, -6.08799458e-02, -2.71661580e-03, -3.56471390e-02, 1.11350343e-02, -2.56258100e-02, 2.05378458e-02], [ 2.23517418e-06, -7.50984997e-02, 1.20643079e-02, 2.09403038e-02, 1.78195983e-02, 1.29758865e-02, 9.08046961e-04, -7.36726820e-03, 8.75070691e-03, -3.93393636e-03, 4.93140984e-03, -7.30901957e-03, -2.97645107e-03, 4.18809950e-02, 1.70973763e-02, 5.04964590e-03, -1.19856521e-02, -5.31809032e-03, 9.80186462e-03, 6.25750273e-02, 7.91132450e-04, 5.25092632e-02, 1.74276680e-02, 2.94222236e-02, -9.10501927e-03, 1.67332776e-03, 2.11818963e-02, 1.82696581e-02, -2.05858052e-03, 1.21816993e-06, -2.51303613e-02, -7.00031966e-03, -1.90818682e-02, 2.38134339e-02, 1.85852498e-02, 1.88816562e-02, 0.00000000e+00, 2.37748027e-03, 1.05057806e-02, -5.73440939e-02, -5.06321341e-03, 1.27147976e-02, -9.73396003e-04, 9.90636647e-03, 2.64340937e-02, 1.75562948e-02, 1.66450143e-02, 1.51741505e-03, 2.09401064e-02, 2.11877432e-02, -1.54795647e-02, 1.00296140e-02, -3.15607935e-02, -2.11823583e-02, 5.22204116e-02, 6.60809129e-02, -2.77949665e-02, 1.30063444e-02, 1.25313178e-03, 2.72946656e-02, 4.89727184e-02, -1.69895999e-02, 1.30629092e-02, 1.55981630e-03, 3.25212553e-02, 2.51948833e-04, -5.40875457e-03, -1.81352496e-02, -5.66543639e-03, 2.94677094e-02], [ 0.00000000e+00, -4.97145504e-02, -3.21335346e-03, 2.53131799e-02, 4.13774513e-03, -1.40788406e-03, -5.41490316e-03, -3.77743170e-02, 5.08022308e-03, 1.75616145e-03, 1.51727945e-02, 1.46141499e-02, 4.79898602e-02, 5.34517616e-02, 1.71546526e-02, 3.00733373e-03, 2.73067039e-02, 4.28619981e-03, 1.95392966e-03, -1.62802637e-03, -1.72183998e-02, -9.32161510e-03, 8.53136927e-03, -1.32551491e-02, -2.45124102e-04, 2.85785496e-02, -4.13239896e-02, 9.71779227e-03, -4.31602448e-03, 0.00000000e+00, 2.40202248e-03, -5.27902842e-02, -3.45261432e-02, -1.48351621e-02, 1.39084011e-02, 1.61831826e-02, 0.00000000e+00, -5.43668903e-02, 7.43016601e-03, 1.10875398e-01, -9.89424437e-03, 1.02454871e-02, -2.66520679e-03, -3.40092704e-02, -2.39770263e-02, -4.27880883e-02, -1.02666572e-01, -4.84505296e-03, -3.08043435e-02, 2.15946883e-02, 7.27008581e-02, -1.56747103e-02, 3.66818905e-03, 2.17209607e-02, -1.88962929e-02, -1.54630542e-02, 3.92973498e-02, 3.84222344e-02, 2.19388902e-02, 1.04343221e-02, 3.98382545e-03, 2.20829248e-03, 1.71164274e-02, 6.02347218e-03, -4.65963036e-03, 5.45533001e-03, 1.20884478e-02, 1.58099681e-02, 3.72115970e-02, 4.13580202e-02], [ 5.59026375e-07, 2.19073892e-02, 1.40187293e-02, 4.17544991e-02, 7.14531541e-03, 1.93578005e-03, 3.68960202e-04, 9.27004218e-03, -7.46112317e-04, 7.42256641e-03, 6.71111792e-03, 1.33904815e-02, -7.87302852e-03, 1.38419718e-02, 2.20113285e-02, 2.13821232e-02, -3.75646353e-03, 2.36315280e-02, 6.75320625e-05, 2.17337422e-02, 3.00095975e-03, 2.45792195e-02, 1.44614577e-02, 2.65318509e-02, -1.13605559e-02, 8.07042047e-03, 3.87509167e-03, 2.12808922e-02, 6.70775771e-04, 7.59959221e-07, -5.00490516e-03, 2.46562809e-03, -7.56940991e-03, 1.34187639e-02, 4.02834602e-02, 1.11031756e-02, 0.00000000e+00, -2.19111145e-03, 2.61621177e-03, 1.26020499e-02, -5.91659546e-03, 4.12940979e-03, 7.18377531e-04, 1.28328502e-02, 9.83218849e-03, 1.39725059e-02, 7.12452829e-03, 9.86279920e-03, 1.29076242e-02, 1.15619823e-02, -2.26510987e-02, -1.01644099e-02, 6.79536909e-03, 3.28451395e-03, -2.67382041e-02, 3.20950896e-03, 2.13438272e-03, 2.84881890e-03, -1.82493776e-03, 2.74327397e-03, 5.69045544e-04, 1.74677297e-02, 2.59828940e-03, 9.38987918e-03, 2.89438292e-02, -1.68662779e-02, 2.16373429e-03, 5.52284718e-02, 1.71898305e-03, 1.84112042e-02], [ 1.92970037e-06, -7.33059645e-03, 6.08241558e-03, 3.09040304e-02, -2.66182497e-02, -1.04110688e-03, 6.55319542e-03, 1.30551010e-02, -6.43594936e-03, 3.14751267e-03, -4.20112303e-03, -3.34899174e-03, 1.19088590e-03, 2.75004655e-02, 1.39071923e-02, -3.07884067e-02, -6.91840798e-03, 3.75329219e-02, -9.75349918e-03, 1.89966261e-02, 1.39560252e-02, 1.82078369e-02, 2.44169421e-02, 3.45724225e-02, 3.83076072e-03, 2.96819210e-03, -2.79713571e-02, 1.44217312e-02, -1.04773343e-02, 2.71573663e-06, -6.85906410e-03, -2.69274563e-02, -1.63219962e-02, 2.06331909e-02, 1.40391737e-02, 9.01226699e-03, 0.00000000e+00, -8.03788006e-03, 2.00668350e-02, 1.17101632e-02, 8.09775665e-03, 1.32986456e-02, 5.69504499e-03, 3.66747677e-02, -3.86754144e-03, 2.89282054e-02, 3.69293764e-02, -2.47304142e-03, 1.71321183e-02, -7.54578412e-03, -9.61725414e-03, 1.61448121e-03, -1.06247291e-02, 3.18420455e-02, 3.38531211e-02, 3.97837535e-02, -5.06854057e-03, 4.11833264e-03, -2.18693763e-02, 0.00000000e+00, 4.73356694e-02, 2.79366970e-02, 1.81355327e-03, 1.18001401e-02, 1.80805475e-02, -4.14552540e-03, -4.44880128e-03, -3.96236032e-03, -2.36448124e-02, 1.94880292e-02], [ 6.82026148e-05, 7.18571916e-02, 4.76970673e-02, 9.57600772e-02, 1.31565988e-01, -1.64669454e-01, -7.16286302e-02, 1.19898200e-01, -2.11336374e-01, 1.21684730e-01, 1.44014955e-01, -2.38232017e-01, 1.46218359e-01, 4.44871187e-03, 7.22973943e-02, 2.78884321e-02, 1.45830631e-01, -1.08749241e-01, -3.77449691e-02, -1.96678042e-02, -1.76536024e-01, 5.42484522e-02, 6.87512457e-02, 2.47801125e-01, 1.44790769e-01, 1.97795749e-01, 4.98535782e-02, 2.33573496e-01, 1.99592710e-02, 2.87592411e-06, -1.66886806e-01, -1.45491302e-01, 2.27394998e-01, 8.62296224e-02, -2.82013416e-03, -1.88887715e-01, 0.00000000e+00, 6.12580180e-02, 1.33518666e-01, -1.46177143e-01, 1.59401894e-02, -2.11909533e-01, -1.65744841e-01, 1.35906681e-01, -2.10718960e-02, 1.91993415e-01, 2.53243625e-01, -1.76544130e-01, 4.48338389e-02, 1.66558981e-01, 4.43497002e-01, 9.72419083e-02, -1.00000083e-01, 3.86097550e-01, -1.01794243e-01, -5.31688072e-02, -7.42531270e-02, 2.60491490e-01, -1.89655364e-01, -1.26332268e-02, 5.77271283e-02, 1.28030062e-01, -1.77092314e-01, 2.60912895e-01, -1.63142890e-01, -2.03584060e-02, 1.95915878e-01, -1.84553564e-02, 1.37363464e-01, -1.53825760e-01], [ 0.00000000e+00, 5.06502837e-02, 1.29767060e-02, 6.28483295e-02, 6.34282753e-02, -1.72892511e-02, -9.57809016e-03, -1.54778808e-02, 9.92524205e-04, -7.11901113e-03, -1.05080698e-02, 1.87457874e-02, 3.44372615e-02, -4.49787974e-02, 6.05929643e-03, -8.81178081e-02, -3.13648731e-02, -3.27146053e-03, 1.48571506e-02, 9.23416018e-03, -1.23183951e-02, 4.48894389e-02, 5.51292300e-02, -4.26419824e-02, 2.51580812e-02, 5.67226112e-03, -1.00748777e-01, -2.07852945e-02, -2.21131444e-02, 0.00000000e+00, -2.12397873e-02, 1.61559135e-03, 1.88499428e-02, 8.83308202e-02, 0.00000000e+00, -2.65170336e-02, 0.00000000e+00, -3.47553790e-02, 3.28133181e-02, -1.98322237e-02, -5.88746481e-02, 3.99494730e-02, -2.77045369e-03, -2.44742259e-04, -3.06502879e-02, -2.40243226e-03, 4.12266403e-02, -4.80247736e-02, 3.03203985e-02, -2.57963911e-02, -3.67290676e-02, 6.71202391e-02, -8.20160732e-02, 2.72179991e-02, 2.75232196e-02, 2.28866339e-02, -4.97791916e-03, 3.20631042e-02, 2.17096210e-02, 0.00000000e+00, -4.09361497e-02, 2.31957436e-02, -2.09520608e-02, -1.32357776e-02, 5.95888197e-02, 0.00000000e+00, -8.28909650e-02, 4.27012593e-02, 2.20479891e-02, 2.66924053e-02], [ 2.80141830e-06, 3.44573855e-02, 1.68616474e-02, -2.16292441e-02, 3.42520177e-02, -2.51318514e-03, -2.99223959e-02, 3.11351269e-02, -7.86145031e-03, 3.79217640e-02, 5.88734746e-02, -2.47791409e-03, -1.02755167e-02, 1.48909986e-02, 6.38209134e-02, -2.27537304e-02, -2.68034339e-02, 8.43681395e-03, 5.44354320e-04, 2.46977806e-03, -6.86615705e-03, 5.15806824e-02, 3.51931527e-02, 4.50379923e-02, 1.82978213e-02, 1.10631064e-02, 5.33567518e-02, 5.02201617e-02, -2.68985629e-02, 2.14576721e-06, 3.81427701e-03, 2.03454792e-02, -1.38287246e-02, -2.03619152e-02, 3.11780423e-02, 1.26634762e-02, 0.00000000e+00, 2.77237929e-02, 9.62861814e-03, -1.84376612e-02, 9.24696773e-03, -5.45406435e-03, -1.83334649e-02, 7.35372305e-04, -1.68629438e-02, 8.47994536e-03, -1.79683268e-02, 1.69715211e-02, 1.91382784e-02, 4.71077561e-02, -4.38696444e-02, 5.99594712e-02, 7.32099563e-02, -4.90160286e-02, 2.35909577e-02, 1.54711753e-02, 1.36633702e-02, -4.24712151e-02, -2.49780267e-02, 7.48580322e-03, 6.65621459e-03, -2.29308978e-02, -2.68799514e-02, 1.97568089e-02, 6.91704005e-02, 2.21450403e-02, -2.95889378e-03, -7.42284954e-03, 3.10672224e-02, 7.04184175e-04], [ 9.68575478e-08, 3.99114043e-02, -9.91953909e-03, 3.25712711e-02, -1.46806389e-02, 5.75292110e-03, 4.84618917e-03, 5.86069934e-03, -4.93802130e-04, 1.27205700e-02, -1.29961967e-03, 6.42070174e-03, -1.86905414e-02, 3.09723020e-02, 1.03429109e-02, 5.13049960e-03, -5.01155853e-04, 3.36712301e-02, -3.53804231e-03, 3.75385135e-02, 8.72188807e-03, 1.75718218e-02, 3.87762487e-03, 3.15307751e-02, -1.69710331e-02, 7.96510279e-03, 1.68320760e-02, 2.14911848e-02, 3.29448655e-03, 4.91738319e-07, 9.02945548e-03, 1.97583884e-02, -3.59129906e-03, 3.39994431e-02, 1.95574164e-02, 1.80864036e-02, 0.00000000e+00, 1.42792165e-02, 4.98805940e-03, 3.40463128e-03, 3.10290605e-03, -1.59697235e-03, 3.81151587e-03, -1.06487721e-02, 2.39797831e-02, 1.12190805e-02, 1.06314570e-02, 1.27038062e-02, 5.77275455e-03, 4.94579226e-03, -7.09898770e-03, 1.74231827e-03, 1.20699406e-06, -1.24374777e-03, 6.46207333e-02, 2.14003474e-02, 1.49126686e-02, 6.48828223e-03, -5.38637117e-03, 1.02209002e-02, 7.54408538e-03, 7.72580504e-03, 3.43549997e-03, 7.72102922e-03, 2.00577080e-02, 4.22081202e-02, -5.91963530e-04, 1.37665868e-02, -9.87207890e-03, 2.15003788e-02]], dtype=float32), array([[-2.91615725e-05, 0.00000000e+00, 0.00000000e+00, ..., 1.14738941e-05, 0.00000000e+00, 0.00000000e+00], [ 8.49024765e-03, -4.31102738e-02, -5.05789816e-02, ..., 1.43139213e-02, -3.98727357e-02, -8.07711482e-03], [ 1.00968346e-01, 7.57245272e-02, -2.73677707e-03, ..., -3.34357619e-02, -2.99723148e-02, 3.60149145e-03], ..., [ 2.19667554e-02, 1.55268461e-02, 1.16532966e-02, ..., -6.69102073e-02, -4.36145514e-02, 8.49020481e-03], [ 4.64867577e-02, 2.68688127e-02, -7.21535087e-03, ..., -1.35802627e-02, 1.19180009e-02, 1.85405612e-02], [-4.64454293e-04, 2.76498124e-02, -5.06965816e-03, ..., 2.36923993e-03, 1.80526227e-02, 2.38161311e-02]], dtype=float32), array([[-0.03158482, 0.03222603, 0.02458179, ..., 0.08041045, 0.01049814, -0.04392172], [ 0.01048962, -0.01012598, -0.01415333, ..., -0.13836852, 0.00485021, -0.01224236], [-0.00740135, -0.00507368, 0.00139559, ..., -0.13694564, -0.00293422, 0.00983405], ..., [ 0.00256901, -0.00454452, 0.00040635, ..., 0.00017413, 0.0033875 , 0.03571756], [-0.01601908, 0.00801374, -0.00454768, ..., -0.10811353, 0.01376003, -0.00697925], [-0.0469159 , 0.01431713, 0.0028784 , ..., 0.02210182, -0.01271321, -0.04000308]], dtype=float32), array([[-1.15624368e-02, 1.15615055e-02], [ 7.26526976e-03, -7.26565719e-03], [ 2.03531086e-02, -2.03537643e-02], [-6.14860654e-03, 6.14863634e-03], [-1.56270713e-02, 1.56272203e-02], [-1.02237239e-03, 1.02218986e-03], [ 1.03584975e-02, -1.03584863e-02], [ 2.06144452e-02, -2.06139088e-02], [-1.49596017e-04, 1.48773193e-04], [ 4.79128957e-03, -4.79125977e-03], [ 6.30669296e-03, -6.30635023e-03], [ 0.00000000e+00, 0.00000000e+00], [-3.16215307e-03, 3.16265225e-03], [ 5.85643947e-02, -5.85644245e-02], [-7.11206347e-04, 7.11120665e-04], [-1.32675469e-02, 1.32675469e-02], [-4.61159647e-03, 4.61125374e-03], [ 1.90031528e-03, -1.90034509e-03], [ 1.55715942e-02, -1.55715793e-02], [-1.19959861e-02, 1.19959936e-02], [-1.39654800e-02, 1.39658749e-02], [-5.74159622e-03, 5.74158132e-03], [-2.31015980e-02, 2.31016129e-02], [-3.14235687e-04, 3.14325094e-04], [ 2.00573653e-02, -2.00573802e-02], [ 0.00000000e+00, 0.00000000e+00], [-1.67117715e-02, 1.67115927e-02], [-9.77398362e-03, 9.77447629e-03], [ 0.00000000e+00, 0.00000000e+00], [-1.44827366e-02, 1.44827068e-02], [-4.94149094e-03, 4.94125485e-03], [-7.87574053e-03, 7.87599012e-03], [-1.95778906e-03, 1.95809081e-03], [ 5.55017591e-02, -5.55018783e-02], [ 1.07816160e-02, -1.07815862e-02], [-3.71644944e-02, 3.71643975e-02], [-2.17199326e-04, 2.16580927e-04], [ 5.87260723e-03, -5.87230921e-03], [-2.86758766e-02, 2.86761075e-02], [-2.14453358e-02, 2.14453079e-02], [-1.89587176e-02, 1.89587921e-02], [-2.95607746e-03, 2.95570493e-03], [ 1.51866078e-02, -1.51871741e-02], [-1.82390213e-05, 1.85519457e-05], [-3.47673893e-04, 3.47673893e-04], [ 2.12892145e-02, -2.12893784e-02], [-3.14503908e-04, 3.14474106e-04], [-2.15784311e-02, 2.15792507e-02], [-3.74348462e-03, 3.74361873e-03], [-4.07821834e-02, 4.07828614e-02]], dtype=float32)], [array([ 1.93715096e-06, -7.33081996e-03, -2.59644464e-02, 3.53913940e-02, -3.17766666e-02, 2.61158086e-02, 3.09046637e-02, -1.39724910e-02, 1.64157562e-02, -1.33834183e-02, -2.69608200e-02, 1.18504055e-02, -3.39354798e-02, 2.74997652e-02, 8.21013749e-03, 3.14356387e-03, -1.39056407e-02, 3.89015339e-02, -4.97761788e-03, 4.34829444e-02, 1.99493878e-02, 1.66720618e-02, -8.93015787e-03, 1.38194822e-02, -3.19721177e-02, -1.15122199e-02, 3.29295881e-02, -1.12637132e-03, 2.69653797e-02, 2.71573663e-06, 2.13105604e-02, 8.05605203e-03, -1.68812610e-02, 6.44891895e-03, 1.83254965e-02, 3.12887095e-02, 0.00000000e+00, -1.13615617e-02, -1.91472992e-02, 2.73335315e-02, -8.05212930e-03, 1.09619126e-02, 2.75111161e-02, 7.87547231e-03, 5.21017611e-02, 8.26306641e-04, 4.48312610e-03, 1.90423653e-02, -9.32561979e-03, -1.31163374e-02, -7.42557272e-03, -8.75856541e-03, -3.20619484e-03, 1.95884481e-02, -3.87861580e-02, 7.03234226e-04, -4.51306719e-03, 1.04681477e-02, 9.19631869e-03, 4.49468940e-03, 8.59957188e-04, -1.29299201e-02, 2.50575244e-02, -6.41628355e-03, 1.74099207e-02, 6.38409331e-03, -1.43194981e-02, -5.27505018e-03, -2.04372257e-02, 3.64710018e-02], dtype=float32), array([ 0.00331618, 0.0539567 , -0.03811157, -0.01109656, 0.02418597, -0.02698364, -0.00161759, -0.02607356, 0.00603188, 0.0166943 , 0.00161197, 0.01395666, -0.02311181, -0.02586061, 0.01232279, 0. , 0.00035155, 0.01375218, 0.01308896, -0.01829513, -0.01323594, 0.01266809, -0.01157204, -0.01712865, 0.00102041, -0.00915188, 0.00994506, 0.0003348 , 0.01999653, -0.01448206, 0.02454722, -0.00036955, 0.01865685, -0.00514613, 0.01885039, -0.02837751, 0.00375455, -0.04659929, 0.00275468, 0.00854283, 0.02362405, -0.02947872, 0.01506701, -0.00749737, -0.02634757, 0.02495424, -0.02921787, 0.01843503, -0.00016499, 0.01456888], dtype=float32), array([ 0.00049266, -0.02066331, -0.0207266 , -0.01927541, 0.02330254, 0.00480025, -0.0293951 , 0.01040582, 0.00090722, -0.0259526 , -0.02241635, 0. , -0.02288882, 0.02825862, 0.00242692, 0.02985683, 0.00974142, 0.00933962, -0.02005748, 0.01471977, 0.00262542, 0.00572324, 0.02363591, 0.00187776, 0.01032351, 0. , -0.01178178, -0.01856739, 0. , 0.01882949, -0.00219416, 0.01156303, 0.00352357, 0.01064696, 0.02187355, 0.00582501, -0.02582753, -0.02192662, -0.00337636, 0.02204296, 0.02771452, 0.0342248 , 0.00142591, 0.0136618 , 0.00158552, 0.01320035, 0.0052134 , -0.03961967, -0.01420508, 0.00930639], dtype=float32), array([-0.01254899, 0.01254895], dtype=float32)])
In [13]:
W = get_parameters(model)[0] B = get_parameters(model)[1]
In [14]:
# BASELINE SCENARIO #buid the model as base line for the shards (sequential) # Number of peers #accordin to what we need ss = int(len(X_train)/n_peers) inputs_in = X_train[0*ss:0*ss+ss] outputs_in = y_train[0*ss:0*ss+ss] def build_model(X_t, y_t): model = Sequential() model.add(Dense(70, input_dim=Features_number, activation='relu')) model.add(Dense(50, activation='relu')) model.add(Dense(50, activation='relu')) model.add(Dense(2, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy']) model.fit(X_t, y_t, batch_size=32, epochs=250, verbose=1, validation_data=((X_test, y_test))) return model
In [15]:
display(model.summary())
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 70) 910 _________________________________________________________________ dense_2 (Dense) (None, 50) 3550 _________________________________________________________________ dense_3 (Dense) (None, 50) 2550 _________________________________________________________________ dense_4 (Dense) (None, 2) 102 ================================================================= Total params: 7,112 Trainable params: 7,112 Non-trainable params: 0 _________________________________________________________________
None
In [16]:
# predict probabilities for test set yhat_probs = model.predict(X_test, verbose=0) # predict crisp classes for test set yhat_classes = model.predict_classes(X_test, verbose=0)
In [17]:
# accuracy: (tp + tn) / (p + n) accuracy = accuracy_score(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1)) print('Accuracy: %f' % accuracy) # precision tp / (tp + fp) precision = precision_score(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1)) print('Precision: %f' % precision) # recall: tp / (tp + fn) recall = recall_score(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1)) print('Recall: %f' % recall) # f1: 2 tp / (2 tp + fp + fn) f1 = f1_score(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1)) print('F1 score: %f' % f1)
Accuracy: 0.836071 Precision: 0.791075 Recall: 0.483871 F1 score: 0.600462
In [18]:
# confusion matrix mat = confusion_matrix(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1)) display(mat) plt.matshow(mat); plt.colorbar() plt.show()
array([[2257, 103], [ 416, 390]], dtype=int64)
In [19]:
# the dectinary FI_dic1= {0:[],1:[],2:[],3:[],4:[],5:[],6:[],7:[],8:[],9:[]}
In [ ]:
# select aa random peer to be the scanner peer peers_selected=random.sample(range(n_peers), 1) scaner = peers_selected[0] # Percentage and number of peers participating at each global training epoch percentage_participants = 1.0 n_participants = int(n_peers * percentage_participants) # Number of global training epochs n_rounds = 10 start_attack_round = 4 end_attack_round = 7 # Number of local training epochs per global training epoch n_local_rounds = 5 # Local batch size local_batch_size = 32 # Local learning rate local_lr = 0.001 # Global learning rate or 'gain' model_substitution_rate = 1.0 # Attack detection / prevention mechanism = {None, 'distance', 'median', 'accuracy', 'krum'} discard_outliers = None # Used in 'dist' attack detection, defines how far the outliers are (1.5 is a typical value) tau = 1.5 # Used in 'accuracy' attack detection, defines the error margin for the accuracy improvement sensitivity = 0.05 # Used in 'krum' attack detection, defines how many byzantine attackers we want to defend against tolerance=4 # Prevent suspicious peers from participating again, only valid for 'dist' and 'accuracy' ban_malicious = False # Clear nans and infinites in model updates clear_nans = True number_for_threshold1 = numpy.empty(20, dtype=float) number_for_threshold2 = numpy.empty(20, dtype=float) for r in range(len(number_for_threshold1)): number_for_threshold1[r] = 0 number_for_threshold2[r] = 0 ######################## # ATTACK CONFIGURATION # ######################## # Percentage of malicious peers r_malicious_peers = 0.0 # Number of malicious peers (absolute or relative to total number of peers) n_malicious_peers = int(n_peers * r_malicious_peers) #n_malicious_peers = 1 # Malicious peers malicious_peer = range(n_malicious_peers) # Target for coalitions common_attack_target = [4,7] # Target class of the attack, per each malicious peer malicious_targets = dict([(p, t) for p,t in zip(malicious_peer, [common_attack_target]*n_malicious_peers)]) # Boosting parameter per each malicious peer common_malicious_boost = 12 malicious_boost = dict([(p, b) for p,b in zip(malicious_peer, [common_malicious_boost]*n_malicious_peers)]) ########### # METRICS # ########### metrics = {'accuracy': [], 'atk_effectivity': [], 'update_distances': [], 'outliers_detected': [], 'acc_no_target': []} #################################### # MODEL AND NETWORK INITIALIZATION # #################################### inputs = X_train[0*ss:0*ss+ss] outputs = y_train[0*ss:0*ss+ss] global_model = build_model(inputs,outputs) n_layers = len(trainable_layers(global_model)) print('Initializing network.') sleep(1) network = [] for i in tqdm(range(n_peers)): ss = int(len(X_train)/n_peers) inputs = X_train[i*ss:i*ss+ss] outputs = y_train[i*ss:i*ss+ss] # network.append(build_model(inputs, outputs)) network.append(global_model) banned_peers = set() ################## # BEGIN TRAINING # ################## for t in range(n_rounds): print(f'Round {t+1}.') sleep(1) ## SERVER SIDE ################################################################# # Fetch global model parameters global_weights, global_biases = get_parameters(global_model) if clear_nans: global_weights, global_biases = nans_to_zero(global_weights, global_biases) # Initialize peer update lists network_weight_updates = [] network_bias_updates = [] # Selection of participant peers in this global training epoch if ban_malicious: good_peers = list([p for i,p in enumerate(network) if i not in banned_peers]) n_participants = n_participants if n_participants <= len(good_peers) else int(len(good_peers) * percentage_participants) participants = random.sample(list(enumerate(good_peers)), n_participants) else: participants = random.sample(list(enumerate(network)),n_participants) ################################################################################ ## CLIENT SIDE ################################################################# for i, local_model in tqdm(participants): # Update local model with global parameters set_parameters(local_model, global_weights, global_biases) # Initialization of user data ss = int(len(X_train)/n_peers) inputs = X_train[i*ss:i*ss+ss] outputs = y_train[i*ss:i*ss+ss] # the scanner peer side if(i == scaner): X_train_local, X_test_local, y_train_local, y_test_local = train_test_split(inputs,outputs, test_size=0.7, random_state=rs) inputs = X_train_local outputs = y_train_local if(t == 0): forest = build_forest(X_train_local,y_train_local) forest_predictions = forest.predict(X_test_local) acc_forest = np.mean([t==p for t,p in zip(y_test_local, forest_predictions)]) FL_predict1 = global_model.predict(X_test_local) imp = scan_wrong(forest_predictions, FL_predict1, forest , y_test_local, X_test_local) FI_dic1[t] = imp # Benign peer # Train local model local_weight_updates, local_bias_updates = get_updates(local_model, inputs, outputs, local_batch_size, n_local_rounds) if clear_nans: local_weight_updates, local_bias_updates = nans_to_zero(local_weight_updates, local_bias_updates) network_weight_updates.append(local_weight_updates) network_bias_updates.append(local_bias_updates) ## END OF CLIENT SIDE ########################################################## ###################################### # SERVER SIDE AGGREGATION MECHANISMS # ###################################### # Aggregate client updates aggregated_weights, aggregated_biases = aggregate(n_layers, n_participants, np.mean, network_weight_updates, network_bias_updates) if clear_nans: aggregated_weights, aggregated_biases = nans_to_zero(aggregated_weights, aggregated_biases) # Apply updates to global model apply_updates(global_model, model_substitution_rate, aggregated_weights, aggregated_biases) # Proceed as in first case aggregated_weights, aggregated_biases = aggregate(n_layers, n_participants, np.mean, network_weight_updates, network_bias_updates) if clear_nans: aggregated_weights, aggregated_biases = nans_to_zero(aggregated_weights, aggregated_biases) apply_updates(global_model, model_substitution_rate, aggregated_weights, aggregated_biases) ################### # COMPUTE METRICS # ################### # Global model accuracy score = global_model.evaluate(X_test, y_test, verbose=0) print(f'Global model loss: {score[0]}; global model accuracy: {score[1]}') metrics['accuracy'].append(score[1]) # Accuracy without the target score = global_model.evaluate(X_test, y_test, verbose=0) metrics['acc_no_target'].append(score[1]) # Distance of individual updates to the final aggregation metrics['update_distances'].append([dist_weights(aggregated_weights, w_i) for w_i in network_weight_updates])
Train on 420 samples, validate on 3166 samples Epoch 1/250 420/420 [==============================] - ETA: 1s - loss: 0.8020 - accuracy: 0.15 - 0s 475us/step - loss: 0.6159 - accuracy: 0.6571 - val_loss: 0.5505 - val_accuracy: 0.7454 Epoch 2/250 420/420 [==============================] - ETA: 0s - loss: 0.4634 - accuracy: 0.81 - 0s 221us/step - loss: 0.5195 - accuracy: 0.7667 - val_loss: 0.5165 - val_accuracy: 0.7454 Epoch 3/250 420/420 [==============================] - ETA: 0s - loss: 0.4718 - accuracy: 0.78 - 0s 197us/step - loss: 0.4819 - accuracy: 0.7667 - val_loss: 0.4850 - val_accuracy: 0.7454 Epoch 4/250 420/420 [==============================] - ETA: 0s - loss: 0.4996 - accuracy: 0.78 - 0s 202us/step - loss: 0.4604 - accuracy: 0.7667 - val_loss: 0.4648 - val_accuracy: 0.7454 Epoch 5/250 420/420 [==============================] - ETA: 0s - loss: 0.4819 - accuracy: 0.81 - 0s 216us/step - loss: 0.4435 - accuracy: 0.7690 - val_loss: 0.4511 - val_accuracy: 0.7527 Epoch 6/250 420/420 [==============================] - ETA: 0s - loss: 0.5019 - accuracy: 0.71 - 0s 180us/step - loss: 0.4364 - accuracy: 0.7667 - val_loss: 0.4513 - val_accuracy: 0.7489 Epoch 7/250 420/420 [==============================] - ETA: 0s - loss: 0.4597 - accuracy: 0.71 - 0s 164us/step - loss: 0.4270 - accuracy: 0.7857 - val_loss: 0.4398 - val_accuracy: 0.7748 Epoch 8/250 420/420 [==============================] - ETA: 0s - loss: 0.4168 - accuracy: 0.75 - 0s 178us/step - loss: 0.4183 - accuracy: 0.8048 - val_loss: 0.4361 - val_accuracy: 0.7761 Epoch 9/250 420/420 [==============================] - ETA: 0s - loss: 0.3737 - accuracy: 0.75 - 0s 199us/step - loss: 0.4082 - accuracy: 0.8167 - val_loss: 0.4300 - val_accuracy: 0.7738 Epoch 10/250 420/420 [==============================] - ETA: 0s - loss: 0.4122 - accuracy: 0.84 - 0s 180us/step - loss: 0.4041 - accuracy: 0.8024 - val_loss: 0.4348 - val_accuracy: 0.7798 Epoch 11/250 420/420 [==============================] - ETA: 0s - loss: 0.5593 - accuracy: 0.62 - 0s 185us/step - loss: 0.3933 - accuracy: 0.8167 - val_loss: 0.4230 - val_accuracy: 0.7817 Epoch 12/250 420/420 [==============================] - ETA: 0s - loss: 0.3269 - accuracy: 0.84 - 0s 190us/step - loss: 0.3849 - accuracy: 0.8167 - val_loss: 0.4186 - val_accuracy: 0.7890 Epoch 13/250 420/420 [==============================] - ETA: 0s - loss: 0.4152 - accuracy: 0.87 - 0s 197us/step - loss: 0.3739 - accuracy: 0.8190 - val_loss: 0.4109 - val_accuracy: 0.7922 Epoch 14/250 420/420 [==============================] - ETA: 0s - loss: 0.3175 - accuracy: 0.87 - 0s 212us/step - loss: 0.3681 - accuracy: 0.8357 - val_loss: 0.4049 - val_accuracy: 0.7953 Epoch 15/250 420/420 [==============================] - ETA: 0s - loss: 0.2642 - accuracy: 0.87 - 0s 254us/step - loss: 0.3643 - accuracy: 0.8190 - val_loss: 0.3993 - val_accuracy: 0.8061 Epoch 16/250 420/420 [==============================] - ETA: 0s - loss: 0.3769 - accuracy: 0.84 - 0s 230us/step - loss: 0.3681 - accuracy: 0.8286 - val_loss: 0.3994 - val_accuracy: 0.8016 Epoch 17/250 420/420 [==============================] - ETA: 0s - loss: 0.4769 - accuracy: 0.84 - 0s 214us/step - loss: 0.3530 - accuracy: 0.8381 - val_loss: 0.3918 - val_accuracy: 0.8092 Epoch 18/250 420/420 [==============================] - ETA: 0s - loss: 0.3678 - accuracy: 0.87 - 0s 202us/step - loss: 0.3377 - accuracy: 0.8310 - val_loss: 0.3871 - val_accuracy: 0.8130 Epoch 19/250 420/420 [==============================] - ETA: 0s - loss: 0.3190 - accuracy: 0.81 - 0s 230us/step - loss: 0.3469 - accuracy: 0.8310 - val_loss: 0.4820 - val_accuracy: 0.7738 Epoch 20/250 420/420 [==============================] - ETA: 0s - loss: 0.3204 - accuracy: 0.84 - 0s 199us/step - loss: 0.3986 - accuracy: 0.8167 - val_loss: 0.4357 - val_accuracy: 0.7713 Epoch 21/250 420/420 [==============================] - ETA: 0s - loss: 0.4518 - accuracy: 0.75 - 0s 211us/step - loss: 0.3504 - accuracy: 0.8286 - val_loss: 0.4117 - val_accuracy: 0.7979 Epoch 22/250 420/420 [==============================] - ETA: 0s - loss: 0.4504 - accuracy: 0.81 - 0s 197us/step - loss: 0.3356 - accuracy: 0.8357 - val_loss: 0.3779 - val_accuracy: 0.8222 Epoch 23/250 420/420 [==============================] - ETA: 0s - loss: 0.3286 - accuracy: 0.81 - 0s 181us/step - loss: 0.3188 - accuracy: 0.8357 - val_loss: 0.4099 - val_accuracy: 0.7982 Epoch 24/250 420/420 [==============================] - ETA: 0s - loss: 0.5356 - accuracy: 0.81 - 0s 178us/step - loss: 0.3277 - accuracy: 0.8310 - val_loss: 0.3724 - val_accuracy: 0.8247 Epoch 25/250 420/420 [==============================] - ETA: 0s - loss: 0.2666 - accuracy: 0.87 - 0s 192us/step - loss: 0.3120 - accuracy: 0.8476 - val_loss: 0.3793 - val_accuracy: 0.8181 Epoch 26/250 420/420 [==============================] - ETA: 0s - loss: 0.1937 - accuracy: 0.90 - 0s 191us/step - loss: 0.3137 - accuracy: 0.8595 - val_loss: 0.3741 - val_accuracy: 0.8272 Epoch 27/250 420/420 [==============================] - ETA: 0s - loss: 0.3485 - accuracy: 0.81 - 0s 228us/step - loss: 0.3016 - accuracy: 0.8571 - val_loss: 0.3725 - val_accuracy: 0.8241 Epoch 28/250 420/420 [==============================] - ETA: 0s - loss: 0.1122 - accuracy: 1.00 - 0s 261us/step - loss: 0.2970 - accuracy: 0.8548 - val_loss: 0.3731 - val_accuracy: 0.8222 Epoch 29/250 420/420 [==============================] - ETA: 0s - loss: 0.2215 - accuracy: 0.90 - 0s 197us/step - loss: 0.3014 - accuracy: 0.8476 - val_loss: 0.3781 - val_accuracy: 0.8174 Epoch 30/250 420/420 [==============================] - ETA: 0s - loss: 0.2079 - accuracy: 0.84 - 0s 204us/step - loss: 0.2942 - accuracy: 0.8548 - val_loss: 0.3813 - val_accuracy: 0.8181 Epoch 31/250 420/420 [==============================] - ETA: 0s - loss: 0.3559 - accuracy: 0.75 - 0s 225us/step - loss: 0.2938 - accuracy: 0.8429 - val_loss: 0.3857 - val_accuracy: 0.8105 Epoch 32/250 420/420 [==============================] - ETA: 0s - loss: 0.4537 - accuracy: 0.78 - 0s 239us/step - loss: 0.3120 - accuracy: 0.8500 - val_loss: 0.3785 - val_accuracy: 0.8184 Epoch 33/250 420/420 [==============================] - ETA: 0s - loss: 0.1957 - accuracy: 0.90 - 0s 229us/step - loss: 0.3323 - accuracy: 0.8500 - val_loss: 0.4259 - val_accuracy: 0.7997 Epoch 34/250 420/420 [==============================] - ETA: 0s - loss: 0.2729 - accuracy: 0.87 - 0s 212us/step - loss: 0.2830 - accuracy: 0.8738 - val_loss: 0.3710 - val_accuracy: 0.8282 Epoch 35/250 420/420 [==============================] - ETA: 0s - loss: 0.4034 - accuracy: 0.78 - 0s 245us/step - loss: 0.2836 - accuracy: 0.8643 - val_loss: 0.3741 - val_accuracy: 0.8272 Epoch 36/250 420/420 [==============================] - ETA: 0s - loss: 0.2914 - accuracy: 0.87 - 0s 219us/step - loss: 0.2785 - accuracy: 0.8667 - val_loss: 0.3696 - val_accuracy: 0.8256 Epoch 37/250 420/420 [==============================] - ETA: 0s - loss: 0.2081 - accuracy: 0.90 - 0s 252us/step - loss: 0.2737 - accuracy: 0.8690 - val_loss: 0.3721 - val_accuracy: 0.8244 Epoch 38/250 420/420 [==============================] - ETA: 0s - loss: 0.2847 - accuracy: 0.90 - 0s 185us/step - loss: 0.2932 - accuracy: 0.8762 - val_loss: 0.3984 - val_accuracy: 0.8149 Epoch 39/250 420/420 [==============================] - ETA: 0s - loss: 0.2068 - accuracy: 0.96 - 0s 275us/step - loss: 0.2746 - accuracy: 0.8786 - val_loss: 0.3747 - val_accuracy: 0.8260 Epoch 40/250 420/420 [==============================] - ETA: 0s - loss: 0.2561 - accuracy: 0.87 - 0s 235us/step - loss: 0.2991 - accuracy: 0.8524 - val_loss: 0.3819 - val_accuracy: 0.8253 Epoch 41/250 420/420 [==============================] - ETA: 0s - loss: 0.2118 - accuracy: 0.90 - 0s 188us/step - loss: 0.2920 - accuracy: 0.8429 - val_loss: 0.3901 - val_accuracy: 0.8143 Epoch 42/250 420/420 [==============================] - ETA: 0s - loss: 0.1978 - accuracy: 0.90 - 0s 230us/step - loss: 0.2979 - accuracy: 0.8667 - val_loss: 0.4308 - val_accuracy: 0.8080 Epoch 43/250 420/420 [==============================] - ETA: 0s - loss: 0.3652 - accuracy: 0.75 - 0s 233us/step - loss: 0.2874 - accuracy: 0.8571 - val_loss: 0.3759 - val_accuracy: 0.8266 Epoch 44/250
420/420 [==============================] - ETA: 0s - loss: 0.2414 - accuracy: 0.90 - 0s 211us/step - loss: 0.2671 - accuracy: 0.8643 - val_loss: 0.3861 - val_accuracy: 0.8152 Epoch 45/250 420/420 [==============================] - ETA: 0s - loss: 0.1696 - accuracy: 0.96 - 0s 211us/step - loss: 0.2889 - accuracy: 0.8738 - val_loss: 0.3877 - val_accuracy: 0.8234 Epoch 46/250 420/420 [==============================] - ETA: 0s - loss: 0.2738 - accuracy: 0.84 - 0s 220us/step - loss: 0.2635 - accuracy: 0.8738 - val_loss: 0.3914 - val_accuracy: 0.8222 Epoch 47/250 32/420 [=>............................] - ETA: 0s - loss: 0.2366 - accuracy: 0.9062
In [ ]:
# sort the feature according to the last epoch and print it with importances sort_index = np.argsort(FI_dic1[9]) for x in sort_index: print(names[x], ', ', FI_dic1[9][x])
In [ ]:
In [ ]:
In [ ]: