From 650c628d2abf0d0fd0a6b8a08a7e33744502814c Mon Sep 17 00:00:00 2001 From: Rami <77787226+RamiHaf@users.noreply.github.com> Date: Tue, 26 Jan 2021 18:38:30 +0100 Subject: [PATCH] Add files via upload --- ..._prediction_via_Feature_importances .ipynb | 2133 +++++++++++++++++ 1 file changed, 2133 insertions(+) create mode 100644 Explaining_the_prediction_via_Feature_importances .ipynb diff --git a/Explaining_the_prediction_via_Feature_importances .ipynb b/Explaining_the_prediction_via_Feature_importances .ipynb new file mode 100644 index 0000000..78ef53d --- /dev/null +++ b/Explaining_the_prediction_via_Feature_importances .ipynb @@ -0,0 +1,2133 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "#IMPORTS\n", + "\n", + "import numpy as np\n", + "import random\n", + "import tensorflow as tf\n", + "import tensorflow.keras as kr\n", + "import tensorflow.keras.backend as K\n", + "from tensorflow.keras.models import Model\n", + "from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense\n", + "from tensorflow.keras.datasets import mnist\n", + "import os\n", + "import csv\n", + "\n", + "from scipy.spatial.distance import euclidean\n", + "from sklearn.metrics import confusion_matrix\n", + "\n", + "from time import sleep\n", + "from tqdm import tqdm\n", + "\n", + "import copy\n", + "import numpy\n", + "from sklearn.datasets import make_classification\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import math\n", + "import seaborn as sns\n", + "from numpy.random import RandomState\n", + "import scipy as scp\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.preprocessing import OneHotEncoder, LabelEncoder\n", + "from keras.models import Sequential\n", + "from keras.layers import Dense\n", + "from keras import optimizers\n", + "from keras.callbacks import EarlyStopping,ModelCheckpoint\n", + "from keras.utils import to_categorical\n", + "from keras import backend as K\n", + "from itertools import product\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.metrics import precision_score\n", + "from sklearn.metrics import recall_score\n", + "from sklearn.metrics import f1_score\n", + "from sklearn.metrics import roc_auc_score\n", + "from sklearn.metrics import confusion_matrix\n", + "\n", + "from sklearn import mixture\n", + "\n", + "from mpl_toolkits.mplot3d import Axes3D\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Enter here the data set you want to explain (adult, activity, or synthatic)\n", + "\n", + "data_set = 'adult'\n", + "\n", + "# Enter here the numb er of peers you want in the experiments\n", + "\n", + "n_peers = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# the random state we will use in the experiments. It can be changed \n", + "\n", + "rs = RandomState(92)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 45222 entries, 0 to 45221\n", + "Data columns (total 14 columns):\n", + "age 45222 non-null float64\n", + "workclass 45222 non-null float64\n", + "educational-num 45222 non-null float64\n", + "marital-status 45222 non-null float64\n", + "occupation 45222 non-null float64\n", + "relationship 45222 non-null float64\n", + "race 45222 non-null float64\n", + "gender 45222 non-null float64\n", + "capital-gain 45222 non-null float64\n", + "capital-loss 45222 non-null float64\n", + "hours-per-week 45222 non-null float64\n", + "native-country 45222 non-null float64\n", + "income_<=50K 45222 non-null uint8\n", + "income_>50K 45222 non-null uint8\n", + "dtypes: float64(12), uint8(2)\n", + "memory usage: 4.2 MB\n" + ] + }, + { + "data": { + "text/plain": [ + "None" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclasseducational-nummarital-statusoccupationrelationshipracegendercapital-gaincapital-losshours-per-weeknative-countryincome_<=50Kincome_>50K
00.1095890.3333330.4000001.00.4615380.60.51.00.0000000.00.3979590.9510
10.2876710.3333330.5333330.00.3076920.01.01.00.0000000.00.5000000.9510
20.1506850.1666670.7333330.00.7692310.01.01.00.0000000.00.3979590.9501
30.3698630.3333330.6000000.00.4615380.00.51.00.0768810.00.3979590.9501
40.2328770.3333330.3333331.00.5384620.21.01.00.0000000.00.2959180.9510
50.6301370.6666670.9333330.00.6923080.01.01.00.0310300.00.3163270.9501
60.0958900.3333330.6000001.00.5384620.81.00.00.0000000.00.3979590.9510
70.5205480.3333330.2000000.00.1538460.01.01.00.0000000.00.0918370.9510
80.6575340.3333330.5333330.00.4615380.01.01.00.0641810.00.3979590.9501
90.2602740.0000000.8000000.00.0000000.01.01.00.0000000.00.3979590.9510
\n", + "
" + ], + "text/plain": [ + " age workclass educational-num marital-status occupation \\\n", + "0 0.109589 0.333333 0.400000 1.0 0.461538 \n", + "1 0.287671 0.333333 0.533333 0.0 0.307692 \n", + "2 0.150685 0.166667 0.733333 0.0 0.769231 \n", + "3 0.369863 0.333333 0.600000 0.0 0.461538 \n", + "4 0.232877 0.333333 0.333333 1.0 0.538462 \n", + "5 0.630137 0.666667 0.933333 0.0 0.692308 \n", + "6 0.095890 0.333333 0.600000 1.0 0.538462 \n", + "7 0.520548 0.333333 0.200000 0.0 0.153846 \n", + "8 0.657534 0.333333 0.533333 0.0 0.461538 \n", + "9 0.260274 0.000000 0.800000 0.0 0.000000 \n", + "\n", + " relationship race gender capital-gain capital-loss hours-per-week \\\n", + "0 0.6 0.5 1.0 0.000000 0.0 0.397959 \n", + "1 0.0 1.0 1.0 0.000000 0.0 0.500000 \n", + "2 0.0 1.0 1.0 0.000000 0.0 0.397959 \n", + "3 0.0 0.5 1.0 0.076881 0.0 0.397959 \n", + "4 0.2 1.0 1.0 0.000000 0.0 0.295918 \n", + "5 0.0 1.0 1.0 0.031030 0.0 0.316327 \n", + "6 0.8 1.0 0.0 0.000000 0.0 0.397959 \n", + "7 0.0 1.0 1.0 0.000000 0.0 0.091837 \n", + "8 0.0 1.0 1.0 0.064181 0.0 0.397959 \n", + "9 0.0 1.0 1.0 0.000000 0.0 0.397959 \n", + "\n", + " native-country income_<=50K income_>50K \n", + "0 0.95 1 0 \n", + "1 0.95 1 0 \n", + "2 0.95 0 1 \n", + "3 0.95 0 1 \n", + "4 0.95 1 0 \n", + "5 0.95 0 1 \n", + "6 0.95 1 0 \n", + "7 0.95 1 0 \n", + "8 0.95 0 1 \n", + "9 0.95 1 0 " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# preprocessing adults data set\n", + "\n", + "if data_set == 'adult':\n", + " #Load dataset into a pandas DataFrame\n", + " adult_data = pd.read_csv('adult_data.csv', na_values='?')\n", + " # Drop all records with missing values\n", + " adult_data.dropna(inplace=True)\n", + " adult_data.reset_index(drop=True, inplace=True)\n", + "\n", + " # Drop fnlwgt, not interesting for ML\n", + " adult_data.drop('fnlwgt', axis=1, inplace=True)\n", + " adult_data.drop('education', axis=1, inplace=True)\n", + "\n", + "# merging some similar features.\n", + " adult_data['marital-status'].replace('Married-civ-spouse', 'Married', inplace=True)\n", + " adult_data['marital-status'].replace('Divorced', 'Unmarried', inplace=True)\n", + " adult_data['marital-status'].replace('Never-married', 'Unmarried', inplace=True)\n", + " adult_data['marital-status'].replace('Separated', 'Unmarried', inplace=True)\n", + " adult_data['marital-status'].replace('Widowed', 'Unmarried', inplace=True)\n", + " adult_data['marital-status'].replace('Married-spouse-absent', 'Married', inplace=True)\n", + " adult_data['marital-status'].replace('Married-AF-spouse', 'Married', inplace=True)\n", + " \n", + " adult_data = pd.concat([adult_data,pd.get_dummies(adult_data['income'], prefix='income')],axis=1)\n", + " adult_data.drop('income', axis=1, inplace=True)\n", + " obj_columns = adult_data.select_dtypes(['object']).columns\n", + " adult_data[obj_columns] = adult_data[obj_columns].astype('category')\n", + " # Convert numerics to floats and normalize\n", + " num_columns = adult_data.select_dtypes(['int64']).columns\n", + " adult_data[num_columns] = adult_data[num_columns].astype('float64')\n", + " for c in num_columns:\n", + " #adult[c] -= adult[c].mean()\n", + " #adult[c] /= adult[c].std()\n", + " adult_data[c] = (adult_data[c] - adult_data[c].min()) / (adult_data[c].max()-adult_data[c].min())\n", + " # 'workclass', 'marital-status', 'occupation', 'relationship' ,'race', 'gender', 'native-country'\n", + " # adult_data['income'] = adult_data['income'].cat.codes\n", + " adult_data['marital-status'] = adult_data['marital-status'].cat.codes\n", + " adult_data['occupation'] = adult_data['occupation'].cat.codes\n", + " adult_data['relationship'] = adult_data['relationship'].cat.codes\n", + " adult_data['race'] = adult_data['race'].cat.codes\n", + " adult_data['gender'] = adult_data['gender'].cat.codes\n", + " adult_data['native-country'] = adult_data['native-country'].cat.codes\n", + " adult_data['workclass'] = adult_data['workclass'].cat.codes\n", + "\n", + " num_columns = adult_data.select_dtypes(['int8']).columns\n", + " adult_data[num_columns] = adult_data[num_columns].astype('float64')\n", + " for c in num_columns:\n", + " #adult[c] -= adult[c].mean()\n", + " #adult[c] /= adult[c].std()\n", + " adult_data[c] = (adult_data[c] - adult_data[c].min()) / (adult_data[c].max()-adult_data[c].min())\n", + " display(adult_data.info())\n", + " display(adult_data.head(10))\n", + " \n", + " adult_data = adult_data.to_numpy()\n", + " \n", + "# splite the data to train and test datasets\n", + " X_train, X_test, y_train, y_test = train_test_split(adult_data[:,:-2],adult_data[:,-2:], test_size=0.07, random_state=rs)\n", + "# the names of the features\n", + " names = ['age','workclass','educational-num','marital-status','occupation',\n", + " 'relationship','race','gender','capital-gain','capital-loss','hours-per-week','native-country']\n", + " Features_number = len(X_train[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "if data_set == 'synthatic':\n", + " #generate the data\n", + " X, y = make_classification(n_samples=1000000, n_features=10, n_redundant=3, n_repeated=2, #n_classes=3, \n", + " n_informative=5, n_clusters_per_class=4, \n", + " random_state=42)\n", + " y = pd.DataFrame(data=y, columns=[\"y\"])\n", + " y = pd.get_dummies(y['y'], prefix='y')\n", + " y = y.to_numpy()\n", + " X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.07, random_state=rs)\n", + " # the names of the features\n", + " names = ['X(0)','X(1)','X(2)','X(3)','X(4)','X(5)','X(6)','X(7)','X(8)','X(9)']\n", + " Features_number = len(X_train[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "if data_set == 'activity':\n", + " #Load dataset into a pandas DataFrame\n", + " activity = pd.read_csv(\"activity_3_original.csv\", sep=',')\n", + "# drop some features that have non value in the majority of the samples\n", + " to_drop = ['subject', 'timestamp', 'heart_rate','activityID']\n", + " activity.drop(axis=1, columns=to_drop, inplace=True)\n", + "# prepare the truth\n", + " activity = pd.concat([activity,pd.get_dummies(activity['motion'], prefix='motion')],axis=1)\n", + " activity.drop('motion', axis=1, inplace=True)\n", + " class_label = [ 'motion_n', 'motion_y']\n", + " predictors = [a for a in activity.columns.values if a not in class_label]\n", + "\n", + " for p in predictors:\n", + " activity[p].fillna(activity[p].mean(), inplace=True)\n", + "\n", + " display(predictors)\n", + " for p in predictors:\n", + " activity[p] = (activity[p]-activity[p].min()) / (activity[p].max() - activity[p].min())\n", + " activity[p].astype('float32')\n", + " activity = activity.to_numpy()\n", + " X_train, X_test, y_train, y_test = train_test_split(activity[:,:-2],activity[:,-2:], test_size=0.07, random_state=rs)\n", + " # the names of the features\n", + " names = ['temp_hand','acceleration_16_x_hand',\n", + " 'acceleration_16_y_hand','acceleration_16_z_hand','acceleration_6_x_hand',\n", + " 'acceleration_6_y_hand','acceleration_6_z_hand','gyroscope_x_hand','gyroscope_y_hand',\n", + " 'gyroscope_z_hand','magnetometer_x_hand','magnetometer_y_hand','magnetometer_z_hand',\n", + " 'temp_chest','acceleration_16_x_chest','acceleration_16_y_chest','acceleration_16_z_chest','acceleration_6_x_chest',\n", + " 'acceleration_6_y_chest','acceleration_6_z_chest','gyroscope_x_chest','gyroscope_y_chest','gyroscope_z_chest',\n", + " 'magnetometer_x_chest','magnetometer_y_chest','magnetometer_z_chest','temp_ankle','acceleration_16_x_ankle',\n", + " 'acceleration_16_y_ankle','acceleration_16_z_ankle','acceleration_6_x_ankle','acceleration_6_y_ankle',\n", + " 'acceleration_6_z_ankle','gyroscope_x_ankle','gyroscope_y_ankle','gyroscope_z_ankle','magnetometer_x_ankle',\n", + " 'magnetometer_y_ankle','magnetometer_z_ankle']\n", + " Features_number = len(X_train[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 42056 samples, validate on 3166 samples\n", + "Epoch 1/2\n", + "42056/42056 [==============================] - ETA: 2:15 - loss: 0.7931 - accuracy: 0.28 - ETA: 5s - loss: 0.5651 - accuracy: 0.6988 - ETA: 3s - loss: 0.5165 - accuracy: 0.72 - ETA: 2s - loss: 0.4896 - accuracy: 0.74 - ETA: 2s - loss: 0.4646 - accuracy: 0.76 - ETA: 2s - loss: 0.4489 - accuracy: 0.77 - ETA: 1s - loss: 0.4412 - accuracy: 0.77 - ETA: 1s - loss: 0.4316 - accuracy: 0.78 - ETA: 1s - loss: 0.4258 - accuracy: 0.78 - ETA: 1s - loss: 0.4224 - accuracy: 0.79 - ETA: 1s - loss: 0.4158 - accuracy: 0.79 - ETA: 1s - loss: 0.4118 - accuracy: 0.79 - ETA: 1s - loss: 0.4073 - accuracy: 0.80 - ETA: 1s - loss: 0.4052 - accuracy: 0.80 - ETA: 1s - loss: 0.4025 - accuracy: 0.80 - ETA: 1s - loss: 0.3992 - accuracy: 0.80 - ETA: 1s - loss: 0.3983 - accuracy: 0.80 - ETA: 0s - loss: 0.3964 - accuracy: 0.80 - ETA: 0s - loss: 0.3930 - accuracy: 0.80 - ETA: 0s - loss: 0.3893 - accuracy: 0.81 - ETA: 0s - loss: 0.3881 - accuracy: 0.81 - ETA: 0s - loss: 0.3873 - accuracy: 0.81 - ETA: 0s - loss: 0.3857 - accuracy: 0.81 - ETA: 0s - loss: 0.3831 - accuracy: 0.81 - ETA: 0s - loss: 0.3811 - accuracy: 0.81 - ETA: 0s - loss: 0.3793 - accuracy: 0.81 - ETA: 0s - loss: 0.3784 - accuracy: 0.81 - ETA: 0s - loss: 0.3764 - accuracy: 0.81 - ETA: 0s - loss: 0.3756 - accuracy: 0.81 - ETA: 0s - loss: 0.3740 - accuracy: 0.82 - ETA: 0s - loss: 0.3726 - accuracy: 0.82 - 2s 41us/step - loss: 0.3720 - accuracy: 0.8217 - val_loss: 0.3452 - val_accuracy: 0.8370\n", + "\n", + "Epoch 00001: val_loss improved from inf to 0.34516, saving model to test.h8\n", + "Epoch 2/2\n", + "42056/42056 [==============================] - ETA: 3s - loss: 0.5095 - accuracy: 0.84 - ETA: 1s - loss: 0.3444 - accuracy: 0.83 - ETA: 1s - loss: 0.3407 - accuracy: 0.84 - ETA: 1s - loss: 0.3383 - accuracy: 0.83 - ETA: 1s - loss: 0.3365 - accuracy: 0.84 - ETA: 1s - loss: 0.3374 - accuracy: 0.84 - ETA: 1s - loss: 0.3386 - accuracy: 0.84 - ETA: 1s - loss: 0.3357 - accuracy: 0.84 - ETA: 1s - loss: 0.3364 - accuracy: 0.84 - ETA: 1s - loss: 0.3353 - accuracy: 0.84 - ETA: 1s - loss: 0.3351 - accuracy: 0.84 - ETA: 1s - loss: 0.3368 - accuracy: 0.83 - ETA: 1s - loss: 0.3381 - accuracy: 0.83 - ETA: 1s - loss: 0.3396 - accuracy: 0.83 - ETA: 0s - loss: 0.3388 - accuracy: 0.83 - ETA: 0s - loss: 0.3384 - accuracy: 0.83 - ETA: 0s - loss: 0.3393 - accuracy: 0.83 - ETA: 0s - loss: 0.3390 - accuracy: 0.83 - ETA: 0s - loss: 0.3392 - accuracy: 0.83 - ETA: 0s - loss: 0.3400 - accuracy: 0.83 - ETA: 0s - loss: 0.3400 - accuracy: 0.83 - ETA: 0s - loss: 0.3410 - accuracy: 0.83 - ETA: 0s - loss: 0.3406 - accuracy: 0.83 - ETA: 0s - loss: 0.3415 - accuracy: 0.83 - ETA: 0s - loss: 0.3419 - accuracy: 0.83 - ETA: 0s - loss: 0.3421 - accuracy: 0.83 - ETA: 0s - loss: 0.3417 - accuracy: 0.83 - ETA: 0s - loss: 0.3410 - accuracy: 0.83 - ETA: 0s - loss: 0.3415 - accuracy: 0.83 - ETA: 0s - loss: 0.3419 - accuracy: 0.83 - ETA: 0s - loss: 0.3420 - accuracy: 0.83 - ETA: 0s - loss: 0.3411 - accuracy: 0.83 - ETA: 0s - loss: 0.3413 - accuracy: 0.83 - 2s 41us/step - loss: 0.3412 - accuracy: 0.8375 - val_loss: 0.3417 - val_accuracy: 0.8370\n", + "\n", + "Epoch 00002: val_loss improved from 0.34516 to 0.34170, saving model to test.h8\n" + ] + } + ], + "source": [ + "#begin federated\n", + "\n", + "earlystopping = EarlyStopping(monitor = 'val_loss',\n", + " min_delta = 0.01,\n", + " patience = 50,\n", + " verbose = 1,\n", + " baseline = 2,\n", + " restore_best_weights = True)\n", + "\n", + "checkpoint = ModelCheckpoint('test.h8',\n", + " monitor='val_loss',\n", + " mode='min',\n", + " save_best_only=True,\n", + " verbose=1)\n", + " \n", + "model = Sequential()\n", + "model.add(Dense(70, input_dim=Features_number, activation='relu'))\n", + "model.add(Dense(50, activation='relu'))\n", + "model.add(Dense(50, activation='relu'))\n", + "model.add(Dense(2, activation='softmax'))\n", + "model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])\n", + "history = model.fit(X_train, y_train,\n", + "epochs=2,\n", + "validation_data=(X_test, y_test),\n", + "callbacks = [checkpoint, earlystopping],\n", + "shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "#AUXILIARY METHODS FOR FEDERATED LEARNING\n", + "\n", + "# RETURN INDICES TO LAYERS WITH WEIGHTS AND BIASES\n", + "def trainable_layers(model):\n", + " return [i for i, layer in enumerate(model.layers) if len(layer.get_weights()) > 0]\n", + "\n", + "# RETURN WEIGHTS AND BIASES OF A MODEL\n", + "def get_parameters(model):\n", + " weights = []\n", + " biases = []\n", + " index = trainable_layers(model)\n", + " for i in index:\n", + " weights.append(copy.deepcopy(model.layers[i].get_weights()[0]))\n", + " biases.append(copy.deepcopy(model.layers[i].get_weights()[1])) \n", + " \n", + " return weights, biases\n", + " \n", + "# SET WEIGHTS AND BIASES OF A MODEL\n", + "def set_parameters(model, weights, biases):\n", + " index = trainable_layers(model)\n", + " for i, j in enumerate(index):\n", + " model.layers[j].set_weights([weights[i], biases[i]])\n", + " \n", + "# DEPRECATED: RETURN THE GRADIENTS OF THE MODEL AFTER AN UPDATE \n", + "def get_gradients(model, inputs, outputs):\n", + " \"\"\" Gets gradient of model for given inputs and outputs for all weights\"\"\"\n", + " grads = model.optimizer.get_gradients(model.total_loss, model.trainable_weights)\n", + " symb_inputs = (model._feed_inputs + model._feed_targets + model._feed_sample_weights)\n", + " f = K.function(symb_inputs, grads)\n", + " x, y, sample_weight = model._standardize_user_data(inputs, outputs)\n", + " output_grad = f(x + y + sample_weight)\n", + " \n", + " w_grad = [w for i,w in enumerate(output_grad) if i%2==0]\n", + " b_grad = [w for i,w in enumerate(output_grad) if i%2==1]\n", + " \n", + " return w_grad, b_grad\n", + "\n", + "# RETURN THE DIFFERENCE OF MODELS' WEIGHTS AND BIASES AFTER AN UPDATE \n", + "# NOTE: LEARNING RATE IS APPLIED, SO THE UPDATE IS DIFFERENT FROM THE\n", + "# GRADIENTS. IN CASE VANILLA SGD IS USED, THE GRADIENTS ARE OBTAINED\n", + "# AS (UPDATES / LEARNING_RATE)\n", + "def get_updates(model, inputs, outputs, batch_size, epochs):\n", + " w, b = get_parameters(model)\n", + " #model.train_on_batch(inputs, outputs)\n", + " model.fit(inputs, outputs, batch_size=batch_size, epochs=epochs, verbose=0)\n", + " w_new, b_new = get_parameters(model)\n", + " \n", + " weight_updates = [old - new for old,new in zip(w, w_new)]\n", + " bias_updates = [old - new for old,new in zip(b, b_new)]\n", + " \n", + " return weight_updates, bias_updates\n", + "\n", + "# UPDATE THE MODEL'S WEIGHTS AND PARAMETERS WITH AN UPDATE\n", + "def apply_updates(model, eta, w_new, b_new):\n", + " w, b = get_parameters(model)\n", + " new_weights = [theta - eta*delta for theta,delta in zip(w, w_new)]\n", + " new_biases = [theta - eta*delta for theta,delta in zip(b, b_new)]\n", + " set_parameters(model, new_weights, new_biases)\n", + " \n", + "# FEDERATED AGGREGATION FUNCTION\n", + "def aggregate(n_layers, n_peers, f, w_updates, b_updates):\n", + " agg_w = [f([w_updates[j][i] for j in range(n_peers)], axis=0) for i in range(n_layers)]\n", + " agg_b = [f([b_updates[j][i] for j in range(n_peers)], axis=0) for i in range(n_layers)]\n", + " return agg_w, agg_b\n", + "\n", + "# SOLVE NANS\n", + "def nans_to_zero(W, B):\n", + " W0 = [np.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0) for w in W]\n", + " B0 = [np.nan_to_num(b, nan=0.0, posinf=0.0, neginf=0.0) for b in B]\n", + " return W0, B0\n", + "\n", + "def build_forest(X,y):\n", + " clf=RandomForestClassifier(n_estimators=1000, max_depth=7, random_state=0, verbose = 1)\n", + " clf.fit(X,y)\n", + " return clf\n", + "\n", + "# COMPUTE EUCLIDEAN DISTANCE OF WEIGHTS\n", + "def dist_weights(w_a, w_b):\n", + " wf_a = flatten_weights(w_a)\n", + " wf_b = flatten_weights(w_b)\n", + " return euclidean(wf_a, wf_b)\n", + "\n", + "# TRANSFORM ALL WEIGHT TENSORS TO 1D ARRAY\n", + "def flatten_weights(w_in):\n", + " h = w_in[0].reshape(-1)\n", + " for w in w_in[1:]:\n", + " h = np.append(h, w.reshape(-1))\n", + " return h\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# scan the forest for trees maches the wrong predictions of the black-box\n", + "def scan_wrong(forest_predictions, FL_predict1, forest , y_test_local, X_test_local):\n", + " sum_feature_improtance= 0\n", + " overal_wrong_feature_importance = 0\n", + " counter = 0\n", + " second_counter = 0\n", + " never_seen = 0\n", + " avr_wrong_importance = 0\n", + " FL_predict1 = np.argmax(FL_predict1, axis=1)\n", + " forest_predictions = np.argmax(forest_predictions, axis=1)\n", + " y_test_local = np.argmax(y_test_local, axis=1)\n", + " for i in range (len(FL_predict1)):\n", + " i_tree = 0\n", + "# if the black-box got a wrong prediction\n", + " if (FL_predict1[i] != y_test_local[i]):\n", + "# getting the prediction of the trees one by one\n", + " for tree_in_forest in forest.estimators_:\n", + " sample = X_test_local[i].reshape(1, -1)\n", + " temp = forest.estimators_[i_tree].predict(sample)\n", + " temp = np.argmax(temp, axis=1)\n", + " i_tree = i_tree + 1\n", + "# if the prediction of the tree maches the predictions of the black-box\n", + " if(FL_predict1[i] == temp):\n", + "# getting the features importances\n", + " sum_feature_improtance = sum_feature_improtance + tree_in_forest.feature_importances_\n", + " counter = counter + 1\n", + "# if we have trees maches the black-box predictions\n", + " if(counter>0):\n", + " ave_feature_importence = sum_feature_improtance/counter\n", + " overal_wrong_feature_importance = ave_feature_importence + overal_wrong_feature_importance\n", + " second_counter = second_counter + 1\n", + " counter = 0\n", + " sum_feature_improtance = 0\n", + "# if there is no trees maches the black-box predictions\n", + " else:\n", + " if(FL_predict1[i] != y_test_local[i]):\n", + " never_seen = never_seen +1\n", + "# getting the average features importances for all the samples that had wrong predictions.\n", + " if(second_counter>0):\n", + " avr_wrong_importance = overal_wrong_feature_importance / second_counter\n", + " return forest.feature_importances_" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0, 1, 2, 3]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainable_layers(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([array([[ 1.39432400e-01, 8.84631574e-02, -4.47415888e-01,\n", + " 1.23670131e-01, -2.65049934e-01, 2.56673127e-01,\n", + " 2.82177985e-01, -3.88451487e-01, -8.48813355e-02,\n", + " -4.55360711e-01, -2.55180508e-01, -1.34169891e-01,\n", + " -4.19932574e-01, 9.50885192e-02, -4.10533138e-02,\n", + " 1.23161055e-01, -3.34913731e-01, -3.29331495e-02,\n", + " -2.09537312e-01, 2.89370805e-01, -2.42182449e-01,\n", + " 9.41318497e-02, -8.54814351e-02, -2.53278345e-01,\n", + " 7.38841221e-02, 9.76254940e-02, 9.64644551e-03,\n", + " 4.62163612e-02, 1.47847623e-01, 3.28071006e-02,\n", + " -2.16738522e-01, -5.52587435e-02, -1.01704948e-01,\n", + " 2.31297538e-01, -3.01694840e-01, 2.23755836e-02,\n", + " -2.37541839e-01, -8.33741352e-02, -3.33046556e-01,\n", + " -3.82800475e-02, -2.60576427e-01, 1.35413051e-01,\n", + " 5.84374070e-02, -1.67372033e-01, 7.50956163e-02,\n", + " -2.44477212e-01, 4.34608996e-01, 1.95100769e-01,\n", + " -1.71157598e-01, 2.94538945e-01, -2.78368771e-01,\n", + " 3.23733628e-01, 7.93107301e-02, 2.28328109e-01,\n", + " 6.06352724e-02, -7.03767091e-02, 1.33410409e-01,\n", + " 1.21751621e-01, 1.97286800e-01, -9.21699479e-02,\n", + " -3.15490931e-01, 2.30563477e-01, -3.28507647e-02,\n", + " 8.77456143e-02, 5.48780151e-02, -4.60406430e-02,\n", + " -1.89183086e-01, 3.93763036e-02, 2.96199113e-01,\n", + " -2.79987492e-02],\n", + " [ 3.03673856e-02, 1.97539851e-02, -1.50838614e-01,\n", + " 1.14162855e-01, 1.80196881e-01, -1.22831225e-01,\n", + " 7.67074972e-02, -1.13835640e-01, -1.38265222e-01,\n", + " -6.62374571e-02, 1.81205988e-01, -2.81262010e-01,\n", + " 1.72400191e-01, 2.07341984e-01, -1.34065270e-01,\n", + " 6.87680393e-02, -6.93561733e-02, -2.21116617e-01,\n", + " 1.04925461e-01, 1.02081522e-02, 1.51008025e-01,\n", + " -2.92544812e-02, -1.05958931e-01, 1.61262244e-01,\n", + " 1.58383980e-01, -1.24027103e-01, -1.80273309e-01,\n", + " 2.02690706e-01, 1.30619720e-01, -1.44045368e-01,\n", + " 5.87314926e-02, -6.84582517e-02, 5.60571887e-02,\n", + " -1.27603471e-01, 2.20635161e-01, 1.71862170e-01,\n", + " 1.77298188e-02, 1.31710157e-01, -2.06363559e-01,\n", + " 1.41939849e-01, 4.67592143e-02, -2.25164890e-01,\n", + " 2.84170844e-02, -1.87025517e-01, 2.21437346e-02,\n", + " 2.89680868e-01, 2.44593516e-01, 5.39705567e-02,\n", + " 1.68798208e-01, -9.17015448e-02, -9.46003050e-02,\n", + " -8.50451589e-02, -9.65483636e-02, 2.15933964e-01,\n", + " 3.86347598e-03, -2.29437221e-02, 8.44280720e-02,\n", + " 1.96231887e-01, 3.78342345e-02, 1.12372516e-02,\n", + " 7.45132491e-02, -1.45243943e-01, 1.38520822e-01,\n", + " 1.27623096e-01, 9.93933976e-02, 7.73796961e-02,\n", + " 1.07909396e-01, 5.35671674e-02, -2.25077912e-01,\n", + " 1.48774251e-01],\n", + " [-2.48966157e-01, -1.18819617e-01, 3.78526822e-02,\n", + " -4.11971584e-02, 5.32225370e-02, 2.79902488e-01,\n", + " 3.43969136e-01, -5.78653142e-02, 1.67140678e-01,\n", + " -2.94734612e-02, 1.13698818e-01, -2.92426739e-02,\n", + " -1.79812416e-01, 2.88941506e-02, -1.41450733e-01,\n", + " -7.92392809e-03, -1.35528877e-01, -2.56182522e-01,\n", + " -2.33598545e-01, -5.47329225e-02, 2.58110791e-01,\n", + " -2.45282829e-01, 4.75647040e-02, 4.78960238e-02,\n", + " -6.56322390e-02, -6.67297915e-02, 1.69852525e-01,\n", + " -1.50414899e-01, 2.58721203e-01, 1.14194579e-01,\n", + " 2.65164256e-01, 8.89386758e-02, 2.67333359e-01,\n", + " -3.09747636e-01, -1.52420253e-02, 2.57288337e-01,\n", + " 1.46575630e-01, 8.43582675e-02, 1.89198285e-01,\n", + " -5.13301976e-02, -1.45431489e-01, 1.83323875e-01,\n", + " 2.22104147e-01, -7.55850300e-02, 1.44288674e-01,\n", + " -1.75847083e-01, -1.43846169e-01, 1.33877620e-01,\n", + " 1.63822114e-01, -1.28378317e-01, -2.10838597e-02,\n", + " -2.69852519e-01, 1.04066990e-01, 2.06833377e-01,\n", + " -1.28662705e-01, 1.49911791e-01, -2.75938064e-01,\n", + " -3.31552997e-02, 2.19017982e-01, 6.46202068e-04,\n", + " 1.66913256e-01, -1.72089741e-01, 9.96593982e-02,\n", + " -2.43812397e-01, -8.03031027e-02, -1.92508698e-01,\n", + " -3.14832121e-01, -9.16534588e-02, -3.15453112e-01,\n", + " 1.48415402e-01],\n", + " [-2.35771656e-01, 3.27018127e-02, 1.60873935e-01,\n", + " -1.28616795e-01, 3.11803758e-01, -2.35472228e-02,\n", + " -1.39719948e-01, 1.74694061e-02, 7.51914829e-02,\n", + " 2.35624880e-01, 7.33765140e-02, 2.16503426e-01,\n", + " 4.06566672e-02, -2.05656707e-01, 1.96258724e-01,\n", + " 5.99774197e-02, -1.27538797e-02, -6.30170330e-02,\n", + " -1.16274104e-01, -1.43104732e-01, -1.37973130e-01,\n", + " -1.91767380e-01, 3.22461128e-01, 2.99887396e-02,\n", + " 2.64688015e-01, -2.45580390e-01, -2.41390377e-01,\n", + " -1.29994661e-01, -1.80605844e-01, -2.61187732e-01,\n", + " 1.44567609e-01, 1.88110307e-01, 1.73101038e-01,\n", + " 2.86840070e-02, -1.33754045e-01, -5.33887371e-02,\n", + " -1.13288000e-01, -8.15718770e-02, 2.53453523e-01,\n", + " -1.54690027e-01, -1.32443011e-02, -6.94180205e-02,\n", + " -1.20536266e-02, -2.19712891e-02, -2.30549023e-01,\n", + " 2.46970072e-01, -1.82330459e-02, -1.24268174e-01,\n", + " 2.66243219e-01, 1.11885495e-01, 8.33856687e-02,\n", + " -1.06503241e-01, -2.80220248e-02, -1.17930442e-01,\n", + " 2.08708122e-01, 7.04001710e-02, -1.37973502e-02,\n", + " 1.89776018e-01, -7.30874389e-02, -2.11521506e-01,\n", + " 1.42071024e-01, 2.42409576e-02, 8.69186819e-02,\n", + " 3.34844626e-02, -2.07044452e-01, -1.04645088e-01,\n", + " 1.51515082e-01, -1.95780490e-02, 2.13911623e-01,\n", + " 9.59823653e-02],\n", + " [-2.26251304e-01, -4.98282760e-02, 8.57945010e-02,\n", + " 1.85095415e-01, 1.94030240e-01, 1.70300901e-01,\n", + " -1.48310944e-01, -1.68697998e-01, 1.38381734e-01,\n", + " -8.20567235e-02, 1.35808028e-02, -1.75055087e-01,\n", + " 2.08388101e-02, -2.22936451e-01, -7.68952891e-02,\n", + " -4.24526669e-02, 4.03720774e-02, 2.34893888e-01,\n", + " -1.57926619e-01, -2.40865514e-01, 1.67401552e-01,\n", + " 2.16235057e-01, -1.50564939e-01, 1.77459866e-01,\n", + " -1.02011845e-01, 9.56041086e-03, -1.36439502e-01,\n", + " 1.67499810e-01, 1.46594793e-01, -2.37665162e-03,\n", + " 2.35330492e-01, -4.87338640e-02, -8.25209543e-02,\n", + " -7.34776333e-02, 2.11637601e-01, -8.63815099e-02,\n", + " -2.52601802e-01, -1.03249528e-01, 1.14807218e-01,\n", + " 1.93410560e-01, -7.48374164e-02, 4.09806073e-02,\n", + " -1.25015989e-01, 1.75860271e-01, 1.65006757e-01,\n", + " 1.63865000e-01, 1.56919926e-01, -2.22888529e-01,\n", + " -3.29164751e-02, 4.06037048e-02, 2.24684268e-01,\n", + " 1.01046182e-01, -1.53632820e-01, -1.65310353e-01,\n", + " 4.86176573e-02, -2.46649399e-01, -2.84075760e-03,\n", + " 1.55264661e-01, 4.27330621e-02, -2.05510065e-01,\n", + " 1.62713528e-01, -3.14808562e-02, 1.86110288e-01,\n", + " 6.84845075e-02, 4.47224490e-02, -3.40451181e-01,\n", + " 1.40326787e-02, 2.19547436e-01, 7.52496868e-02,\n", + " 1.09770238e-01],\n", + " [-2.14519277e-01, -1.97733060e-01, -1.04191333e-01,\n", + " -1.52826672e-02, 1.04496861e-03, -6.56969398e-02,\n", + " -7.04714730e-02, -1.19291015e-01, 1.01761602e-01,\n", + " 7.52121955e-02, -2.15532720e-01, -1.47176266e-01,\n", + " 1.51603609e-01, -1.83050726e-02, -3.25457342e-02,\n", + " -5.11338934e-02, 1.16198196e-03, -2.66087204e-01,\n", + " 7.53995031e-02, 7.98415840e-02, 4.19246480e-02,\n", + " -7.96627849e-02, 1.22839414e-01, 1.80793643e-01,\n", + " -2.73334742e-01, 5.54925241e-02, 1.19968027e-01,\n", + " 1.63323641e-01, 1.11940101e-01, -1.46585837e-01,\n", + " 1.94005132e-01, 1.88561931e-01, -5.62924668e-02,\n", + " -4.18225750e-02, -1.56423241e-01, -2.25715101e-01,\n", + " -4.82656956e-02, 2.14031748e-02, 2.10182130e-01,\n", + " -3.18871409e-01, -7.38589093e-02, -2.32924759e-01,\n", + " 8.74556080e-02, -1.10086516e-01, 1.84157446e-01,\n", + " -1.46957889e-01, -1.06122330e-01, 2.88575172e-01,\n", + " 7.43130967e-02, 1.63028061e-01, 2.40940854e-01,\n", + " 8.84263813e-02, 1.86871052e-01, -1.03018314e-01,\n", + " -2.51245052e-02, -2.32590944e-01, 2.58567259e-02,\n", + " 1.24988005e-01, 4.27892543e-02, 6.42778203e-02,\n", + " 2.41022035e-01, -5.46587259e-02, -1.77857980e-01,\n", + " 3.70368622e-02, 2.42744144e-02, 1.84613451e-01,\n", + " 2.30415717e-01, -1.80632919e-01, -9.84579027e-02,\n", + " -4.87778150e-02],\n", + " [-2.97077070e-03, -9.92525965e-02, 9.59780440e-02,\n", + " -1.05714351e-01, -2.09908143e-01, 2.08500147e-01,\n", + " -9.31153223e-02, 2.99151987e-01, 4.34016176e-02,\n", + " -2.24611446e-01, 3.31769064e-02, 2.14490488e-01,\n", + " -2.24754527e-01, -1.74998924e-01, -4.15243544e-02,\n", + " -1.69698030e-01, 2.80564696e-01, 1.17882535e-01,\n", + " -9.80678648e-02, 3.15327570e-03, -2.08990425e-01,\n", + " 1.49431065e-01, -1.39306724e-01, 2.40346678e-02,\n", + " 2.40564555e-01, -5.09837978e-02, 2.17804000e-01,\n", + " 1.35088935e-01, 8.79955664e-02, -5.64928725e-02,\n", + " 4.61013429e-02, 6.54249862e-02, -8.42749923e-02,\n", + " 2.62729824e-01, -3.99206020e-02, -1.17483221e-01,\n", + " -1.40452668e-01, -1.06828704e-01, -1.74000204e-01,\n", + " -4.49550189e-02, 2.60878950e-01, 2.07423091e-01,\n", + " -9.15924609e-02, 1.91001654e-01, -1.47255644e-01,\n", + " 7.95471966e-02, -1.70050204e-01, 5.61165512e-02,\n", + " -1.48466706e-01, 1.08682081e-01, 2.04737335e-02,\n", + " -1.74528554e-01, -9.47896019e-02, 1.73530400e-01,\n", + " -1.12356387e-01, 9.92965326e-02, 1.26004890e-01,\n", + " -2.32813179e-01, 9.49711502e-02, -2.34253883e-01,\n", + " -2.76989549e-01, -7.66268969e-02, -3.41671556e-02,\n", + " -5.10511408e-03, -5.79159260e-02, -6.46380782e-02,\n", + " -5.50055876e-02, -3.11404735e-01, 2.45275497e-01,\n", + " -2.22187296e-01],\n", + " [ 4.54801060e-02, 2.56455511e-01, -1.82633027e-01,\n", + " -1.01602580e-02, -8.93032998e-02, 1.04237944e-01,\n", + " 5.84088564e-02, 1.54823989e-01, -1.07336426e-02,\n", + " 2.69688278e-01, 6.16033142e-03, -6.09616982e-03,\n", + " 8.98296311e-02, 1.78536490e-01, -1.43777172e-03,\n", + " -9.94328558e-02, -4.55807038e-02, 9.91010815e-02,\n", + " -4.42102812e-02, 3.77892517e-02, 1.33471981e-01,\n", + " 7.44501278e-02, 1.62690468e-02, 2.23104075e-01,\n", + " -2.61054993e-01, 3.15811366e-01, -2.96082795e-01,\n", + " 1.78025752e-01, -2.63285220e-01, -5.37474826e-02,\n", + " -9.58651751e-02, -2.15012103e-01, -4.33603339e-02,\n", + " -2.60652751e-01, -5.41594252e-02, 2.35952377e-01,\n", + " -3.74763012e-02, 1.91953376e-01, 1.17158510e-01,\n", + " 3.78518994e-03, -6.19563572e-02, 2.10780635e-01,\n", + " 1.62149847e-01, -1.30085796e-01, 1.28252106e-03,\n", + " 2.28483707e-01, -1.14689972e-02, -8.24389532e-02,\n", + " -1.77851245e-01, -1.37649611e-01, 1.65123567e-01,\n", + " 1.03654794e-01, 8.36220309e-02, 1.99557766e-02,\n", + " 6.00132421e-02, -3.04210056e-02, -2.81973660e-01,\n", + " -2.42123492e-02, -2.17434868e-01, -9.64278206e-02,\n", + " -1.85030416e-01, -2.62960136e-01, 5.34782112e-02,\n", + " 1.58508420e-01, 1.65380761e-01, -3.85079943e-02,\n", + " 2.55265355e-01, 5.09922206e-02, -1.47566527e-01,\n", + " 7.40251169e-02],\n", + " [ 1.29649222e-01, -3.14282179e-02, -6.06167972e-01,\n", + " -2.50955880e-01, -3.46874207e-01, 7.49993503e-01,\n", + " 7.28010595e-01, -1.06399655e+00, 1.06234324e+00,\n", + " -3.55233133e-01, -5.50140023e-01, 1.00409508e+00,\n", + " -5.45210958e-01, 1.93181217e-01, -7.01776028e-01,\n", + " -2.10634783e-01, -4.23527777e-01, 3.09440106e-01,\n", + " -1.91907719e-01, 2.85458267e-01, 7.82932997e-01,\n", + " -5.32808244e-01, -4.39185768e-01, -7.65542090e-01,\n", + " -3.82927716e-01, -1.15567505e+00, 1.67764112e-01,\n", + " -9.48192775e-01, 3.64812821e-01, 2.28667915e-01,\n", + " 6.75961256e-01, 8.27623010e-01, -6.38736844e-01,\n", + " -2.00036347e-01, -3.25849533e-01, 9.03906941e-01,\n", + " -2.68816352e-01, -6.27302647e-01, -3.23336124e-01,\n", + " 4.70992297e-01, -5.73931932e-01, 9.17997599e-01,\n", + " 7.42488205e-01, -2.06164107e-01, 2.04111740e-01,\n", + " -7.19973087e-01, -3.76782537e-01, 8.55549395e-01,\n", + " -8.38361323e-01, -9.57333803e-01, -5.20633638e-01,\n", + " -3.67659301e-01, 1.50605768e-01, -7.64182091e-01,\n", + " -3.19448918e-01, -5.01123592e-02, 1.64251193e-01,\n", + " -7.17021644e-01, 6.97100699e-01, -1.25111267e-01,\n", + " -4.96421248e-01, -5.05610764e-01, 1.01232016e+00,\n", + " -1.09313202e+00, 3.20109189e-01, -1.06782168e-01,\n", + " -9.03548539e-01, -2.81452984e-01, -2.17785537e-01,\n", + " 6.68265998e-01],\n", + " [-2.42571607e-01, 2.04211175e-01, -4.92268875e-02,\n", + " -1.63620815e-01, 4.04583551e-02, 3.45696330e-01,\n", + " 3.54173370e-02, 8.10830146e-02, -1.61551312e-03,\n", + " -1.48698622e-02, 1.94258001e-02, 1.15005746e-01,\n", + " -1.20848659e-02, 2.12298751e-01, 8.92769620e-02,\n", + " -7.64900148e-02, -3.41445431e-02, 7.51630887e-02,\n", + " -3.40494029e-02, 2.70350277e-01, 4.42853682e-02,\n", + " 5.13006859e-02, 2.81202555e-01, 1.35484681e-01,\n", + " 7.24686086e-02, -1.34075984e-01, 1.70696169e-01,\n", + " 8.00305977e-03, 2.56366223e-01, 1.33748680e-01,\n", + " 2.25041300e-01, -7.13687986e-02, -4.96987440e-02,\n", + " 3.10503058e-02, -2.25651234e-01, 3.94519985e-01,\n", + " 2.15304196e-01, 1.15869548e-02, 1.47072956e-01,\n", + " 3.18337977e-01, -6.86229253e-03, -5.09570874e-02,\n", + " 2.29824454e-01, 2.61031240e-02, 2.89728433e-01,\n", + " 3.48875783e-02, -5.50319031e-02, 1.21588549e-02,\n", + " -4.12969440e-02, 1.07327215e-01, 1.35437414e-01,\n", + " -2.93096341e-02, 3.36093381e-02, -1.90401971e-01,\n", + " -2.66215026e-01, 6.08073771e-02, 6.91038072e-02,\n", + " -1.98440487e-03, 7.31287152e-02, -2.77851731e-01,\n", + " -1.08341835e-01, -1.85085818e-01, 2.29901448e-01,\n", + " -2.96091676e-01, 2.23246500e-01, -1.44393981e-01,\n", + " -1.93921745e-01, -1.92566663e-01, 1.32529914e-01,\n", + " -1.94337085e-01],\n", + " [-2.53917843e-01, -2.51892120e-01, -1.32432416e-01,\n", + " 1.47464365e-01, -3.17318618e-01, 1.97301418e-01,\n", + " 2.69987226e-01, -4.56497446e-02, 2.30195507e-01,\n", + " 1.32218450e-02, -4.06064779e-01, 2.51328260e-01,\n", + " 5.33021428e-02, -2.66608417e-01, -1.79995075e-01,\n", + " 1.29986405e-01, 2.86205828e-01, 1.35580912e-01,\n", + " -2.18271971e-01, 1.57579169e-01, 2.17058808e-01,\n", + " -1.03528440e-01, -4.87327874e-02, -6.85375035e-02,\n", + " -1.29330382e-01, -1.23507090e-01, 4.80753556e-02,\n", + " -3.16315889e-01, 2.06642285e-01, -1.25930071e-01,\n", + " 4.74674441e-03, -2.28398144e-01, 1.07306920e-01,\n", + " 2.11515024e-01, -1.60666719e-01, 1.23706006e-01,\n", + " -2.24141285e-01, 5.61789013e-02, 1.76867880e-02,\n", + " -9.17073190e-02, 8.19897652e-02, -1.55695155e-02,\n", + " 3.17650735e-01, 2.38761097e-01, 2.45510742e-01,\n", + " 8.75920355e-02, 3.21321398e-01, -1.12799473e-01,\n", + " 2.10606474e-02, -1.81161851e-01, 2.40284592e-01,\n", + " -2.50088274e-01, -2.18976215e-01, 1.12234220e-01,\n", + " 4.77548651e-02, -4.73017395e-02, 1.37630356e-02,\n", + " 1.92280307e-01, 7.14965388e-02, -6.21563159e-02,\n", + " -7.16416389e-02, 1.23388998e-01, 1.82368487e-01,\n", + " 2.31735930e-01, -2.02105597e-01, 1.42061830e-01,\n", + " 1.23616353e-01, 1.56008020e-01, -2.19544828e-01,\n", + " 2.12301493e-01],\n", + " [-7.18890205e-02, -1.92233965e-01, 2.33305559e-01,\n", + " 6.87015578e-02, 8.51642191e-02, -2.19545767e-01,\n", + " 6.00749105e-02, 3.61590572e-02, -6.68269545e-02,\n", + " 1.48855716e-01, 2.30278343e-01, 2.16507941e-01,\n", + " 2.22660348e-01, -2.84734219e-01, 2.37847969e-01,\n", + " -1.55460656e-01, -2.26989180e-01, -6.12188876e-02,\n", + " 1.77810416e-01, 1.45450696e-01, 2.52608925e-01,\n", + " -1.36337921e-01, -1.94631949e-01, 1.51410148e-01,\n", + " 3.44162211e-02, 1.61046118e-01, -1.24759860e-01,\n", + " 1.83450043e-01, 1.75598450e-02, -2.05802217e-01,\n", + " -8.66022483e-02, 6.08737469e-02, -2.22572535e-01,\n", + " -1.20819479e-01, -1.68945014e-01, 2.10285246e-01,\n", + " -2.40360171e-01, -1.79741889e-01, -1.93881094e-01,\n", + " 1.32005673e-03, -8.93675536e-02, -1.65670961e-01,\n", + " -8.00130144e-02, -2.01122567e-01, 1.55159965e-01,\n", + " 4.84559573e-02, -1.92197278e-01, 1.46897465e-01,\n", + " -1.71061575e-01, 5.79360016e-02, -1.45457163e-01,\n", + " 1.78534076e-01, 1.95346713e-01, -9.44947526e-02,\n", + " -2.78981924e-01, -1.16451114e-01, 1.21675292e-02,\n", + " -1.05452980e-03, 2.97299847e-02, 1.15553983e-01,\n", + " -1.47618756e-01, 2.83984572e-01, -9.44054872e-02,\n", + " -6.82652295e-02, 1.54531911e-01, -9.11844522e-02,\n", + " 2.69836523e-02, -3.09856743e-01, 6.67436346e-02,\n", + " 2.40427703e-01]], dtype=float32),\n", + " array([[ 0.04073317, -0.16170031, 0.08170982, ..., 0.12143514,\n", + " -0.03804543, -0.1848121 ],\n", + " [-0.02006259, 0.04184515, 0.20358184, ..., 0.08938669,\n", + " 0.02554417, -0.0998741 ],\n", + " [-0.05848309, -0.13393435, 0.28651938, ..., -0.19336581,\n", + " 0.28697622, -0.18376462],\n", + " ...,\n", + " [-0.13138615, -0.10152157, 0.05253223, ..., 0.16827357,\n", + " 0.09525165, 0.17411834],\n", + " [-0.00976845, -0.10780089, 0.2228816 , ..., 0.1733975 ,\n", + " -0.10156322, 0.03318954],\n", + " [ 0.09590832, -0.01828083, 0.12743485, ..., 0.25016934,\n", + " 0.12800731, -0.10581163]], dtype=float32),\n", + " array([[-0.00384881, -0.12021059, 0.01248708, ..., 0.01682259,\n", + " -0.17754331, 0.02930963],\n", + " [-0.03520177, 0.0117013 , 0.03343487, ..., -0.16231427,\n", + " 0.1756002 , 0.00351096],\n", + " [-0.1752005 , 0.004585 , -0.11959553, ..., -0.17236647,\n", + " 0.28346488, 0.26809448],\n", + " ...,\n", + " [ 0.01488994, 0.00250473, -0.25695267, ..., -0.11059541,\n", + " 0.17581026, -0.23348542],\n", + " [ 0.21297403, 0.24602796, 0.06359419, ..., 0.205567 ,\n", + " 0.04510517, 0.11687386],\n", + " [-0.17597616, 0.07059528, 0.10327347, ..., -0.02315794,\n", + " 0.00959007, -0.01356981]], dtype=float32),\n", + " array([[-2.82930046e-01, 1.26908660e-01],\n", + " [ 2.37486243e-01, -3.81716669e-01],\n", + " [ 9.92978290e-02, -3.47963899e-01],\n", + " [ 3.02352726e-01, -3.74164760e-01],\n", + " [-2.05417976e-01, 2.52470911e-01],\n", + " [-4.55201864e-02, -2.02432677e-01],\n", + " [ 1.73006430e-01, -4.46816646e-02],\n", + " [-2.84130216e-01, -2.26977065e-01],\n", + " [-4.35910234e-03, 3.76744062e-01],\n", + " [ 1.45330116e-01, -3.25348943e-01],\n", + " [ 2.28147835e-01, -2.77784109e-01],\n", + " [-1.19501755e-01, 4.07545753e-02],\n", + " [ 1.01264335e-01, -2.43342578e-01],\n", + " [-1.60477936e-01, 5.24386704e-01],\n", + " [ 6.06849305e-02, 9.89513546e-02],\n", + " [-2.89398909e-01, 1.83537304e-01],\n", + " [ 1.01001307e-01, 2.95499355e-01],\n", + " [-2.97017217e-01, 3.22097719e-01],\n", + " [ 1.97861195e-01, -2.02269956e-01],\n", + " [-7.52512068e-02, 9.88621786e-02],\n", + " [-1.38137221e-01, -4.40452248e-01],\n", + " [-2.33402535e-01, 1.64692864e-01],\n", + " [ 1.27101064e-01, 1.98759794e-01],\n", + " [-3.01784992e-01, 2.12811917e-01],\n", + " [-1.96352318e-01, 1.54295802e-01],\n", + " [ 2.49975443e-01, -2.01082289e-01],\n", + " [-1.38984874e-01, 2.29037121e-01],\n", + " [ 1.06105595e-04, 2.89339125e-01],\n", + " [-3.00384670e-01, -4.83968072e-02],\n", + " [-1.86271910e-02, 3.05029899e-01],\n", + " [ 1.99009106e-03, 2.14236692e-01],\n", + " [-3.34532440e-01, -6.11541159e-02],\n", + " [-1.35282487e-01, 5.65957166e-02],\n", + " [-2.68078923e-01, 1.56603098e-01],\n", + " [-2.48180240e-01, 2.94318020e-01],\n", + " [-1.30787000e-01, -4.86690477e-02],\n", + " [ 2.72127450e-01, 1.19140044e-01],\n", + " [ 2.90248722e-01, -1.86683103e-01],\n", + " [ 9.85520706e-02, -2.18973175e-01],\n", + " [-3.67538421e-03, -1.40206725e-03],\n", + " [ 2.09687546e-01, 2.38504097e-01],\n", + " [-1.00464404e-01, 2.42502570e-01],\n", + " [ 1.55400500e-01, 1.01416796e-01],\n", + " [-3.02865952e-01, 7.42565766e-02],\n", + " [-1.75600335e-01, 3.34860444e-01],\n", + " [-1.38489887e-01, 2.21242890e-01],\n", + " [ 1.80740595e-01, 2.85507560e-01],\n", + " [ 2.81139612e-01, -1.64098963e-01],\n", + " [ 1.56524777e-01, -3.87664348e-01],\n", + " [ 4.04402643e-01, 3.33227254e-02]], dtype=float32)],\n", + " [array([-0.02854579, 0.05311754, 0.04959053, -0.03715162, 0.02330466,\n", + " -0.04269688, -0.01013005, 0.06874033, -0.01974296, 0.06377108,\n", + " 0.01915232, -0.02997592, 0.04385599, -0.03190788, 0.0822746 ,\n", + " -0.0659938 , 0.05614723, -0.05916027, 0.00469705, -0.07926938,\n", + " -0.04140961, 0.04683878, 0.05930374, 0.03659062, 0.01063251,\n", + " 0.08476436, -0.05785139, 0.06547377, -0.06716973, -0.03636409,\n", + " -0.07686226, -0.06849793, 0.05797694, 0.02435303, 0.03509652,\n", + " -0.05251152, 0. , 0.08996248, 0.03742762, -0.01523749,\n", + " 0.03550566, -0.03804714, -0.00484502, -0.00379997, -0.08697824,\n", + " 0.05449862, 0.0753291 , -0.04323955, 0.06112291, 0.07690092,\n", + " -0.0228012 , 0.0299318 , -0.00519692, -0.02068757, -0.00509608,\n", + " -0.05934853, 0.01074862, 0.04968895, -0.06797037, -0.03621136,\n", + " 0.04376265, 0.05112915, -0.01860299, 0.1174278 , -0.07478895,\n", + " -0.03162573, 0.03712742, -0.02407979, 0.01463216, -0.06149809],\n", + " dtype=float32),\n", + " array([-0.00573702, 0.02881279, 0.08348639, 0.05104946, 0.00677452,\n", + " 0.07581051, -0.00896564, 0.05683671, -0.04194446, -0.03258895,\n", + " 0.02584886, -0.01876919, -0.03003295, 0.06011844, -0.03459567,\n", + " -0.02480574, -0.02663354, -0.01426572, 0.04364663, 0.08130559,\n", + " 0.06495807, -0.04167927, 0.05315804, 0.06356885, -0.0253933 ,\n", + " 0.04731547, -0.05747719, 0.05308496, -0.05659075, 0.05597835,\n", + " -0.04872588, -0.04101903, 0.03206719, 0.05267171, -0.05795552,\n", + " 0.02027026, -0.05280019, 0.05442906, -0.01780812, 0.06429685,\n", + " -0.04762284, 0.06730605, -0.05093784, 0.05504497, 0.05357518,\n", + " -0.02114536, 0.07897805, 0.02339305, 0.08138859, -0.03255979],\n", + " dtype=float32),\n", + " array([-0.02713361, 0.09309322, 0.03805388, 0.03424142, -0.01113803,\n", + " 0.01667055, 0.07447569, -0.06668621, -0.03150655, 0.06387825,\n", + " 0.07785708, -0.02531729, 0.0723519 , 0.00347189, -0.0177357 ,\n", + " 0. , -0.03487411, -0.04851071, 0.06759044, -0.03533144,\n", + " 0.02401838, -0.03073125, -0.04450166, -0.02654641, -0.02827855,\n", + " 0. , -0.01567897, -0.0040624 , -0.02975524, -0.03281569,\n", + " -0.02176444, -0.03663797, -0.01836163, -0.01098274, -0.05059136,\n", + " -0.02361586, 0.06188901, 0.07352342, 0.05501923, -0.04410602,\n", + " -0.04001745, -0.05807052, 0.05206716, -0.01867674, -0.0089961 ,\n", + " -0.01759916, -0.01311922, -0.01491712, 0.05937983, 0.06336498],\n", + " dtype=float32),\n", + " array([ 0.05128358, -0.05128355], dtype=float32)])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_parameters(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([array([[ 1.49011612e-06, -1.16565228e-02, 7.42741823e-02,\n", + " 1.84727460e-02, -8.54203105e-03, -1.07147992e-02,\n", + " -1.21345818e-02, 7.30311871e-02, -9.66983289e-03,\n", + " 6.97897077e-02, 3.34345400e-02, 8.74996185e-05,\n", + " 8.57096016e-02, 1.60385072e-02, 1.19651444e-02,\n", + " 2.49370933e-03, 8.43369067e-02, 1.64110810e-02,\n", + " 2.05034018e-02, -1.07030272e-02, 1.41942352e-02,\n", + " 8.08978826e-03, 3.21836621e-02, 5.65186143e-02,\n", + " -4.39171195e-02, 4.30863351e-03, -1.15206484e-02,\n", + " 2.10180618e-02, -7.25415349e-03, 3.23727727e-06,\n", + " -3.62303257e-02, 1.32649988e-02, 6.42685592e-03,\n", + " 1.92415118e-02, 3.42182815e-02, 1.04672834e-02,\n", + " 0.00000000e+00, 2.78882086e-02, 4.19589877e-02,\n", + " -2.18840521e-02, 9.14855003e-02, 4.81577218e-03,\n", + " -5.83276898e-03, 1.10103935e-02, -1.27351061e-02,\n", + " 1.04498863e-02, -4.18264568e-02, 2.71596760e-02,\n", + " 2.05704421e-02, -3.03956270e-02, 3.23033035e-02,\n", + " -4.21546996e-02, 4.45771776e-02, -3.91504467e-02,\n", + " -8.15168023e-02, 2.26370320e-02, -4.53699529e-02,\n", + " -4.22071815e-02, -1.31393373e-02, 1.32917091e-02,\n", + " 3.45049202e-02, -2.07801014e-02, -3.54883261e-03,\n", + " 1.25047117e-02, -6.88902661e-03, -3.12848352e-02,\n", + " 1.91906095e-02, 4.19499911e-02, -6.82501495e-02,\n", + " -1.53118074e-02],\n", + " [ 1.26473606e-06, 3.19551006e-02, -9.95661318e-03,\n", + " 1.19727850e-02, -1.79602206e-03, 1.53238177e-02,\n", + " 3.48728746e-02, -6.26268983e-03, 1.16313845e-02,\n", + " -1.24692582e-02, 7.92261958e-03, 1.70102119e-02,\n", + " -6.51185215e-03, 1.54842287e-02, -1.05841383e-02,\n", + " 6.68349117e-03, -5.42246550e-03, 2.38585621e-02,\n", + " -4.80242074e-03, 4.91370745e-02, 1.68562233e-02,\n", + " 1.78126954e-02, 1.90734789e-02, 1.65916681e-02,\n", + " 1.61990523e-03, -8.57291371e-03, 2.71531343e-02,\n", + " -7.27659464e-03, 4.15879637e-02, 2.78651714e-06,\n", + " -5.89160994e-03, -5.41770346e-02, 2.21652985e-02,\n", + " 2.85971761e-02, 2.64423043e-02, 2.02812403e-02,\n", + " 0.00000000e+00, 4.61198390e-03, 1.03888065e-02,\n", + " -1.25160664e-02, 6.73976541e-03, 7.16526806e-03,\n", + " 2.04948001e-02, 3.80229056e-02, 5.23998961e-03,\n", + " 1.47194266e-02, -3.79353762e-04, 1.02847666e-02,\n", + " 2.61461437e-02, -2.40194798e-02, 6.51396215e-02,\n", + " -3.63333039e-02, -5.51320203e-02, 6.67399764e-02,\n", + " 2.30744481e-04, -1.98918562e-02, -1.48451179e-02,\n", + " 2.16108412e-02, 2.36949362e-02, 1.34854689e-02,\n", + " 2.18952857e-02, -3.03417742e-02, 1.44592375e-02,\n", + " -5.22305071e-03, 2.16116756e-02, -1.71370506e-02,\n", + " -1.07840151e-02, 2.18297653e-02, -4.98308837e-02,\n", + " 6.17059022e-02],\n", + " [ 7.59959221e-07, 6.49914891e-03, 6.02702424e-03,\n", + " 4.44378480e-02, -3.70854139e-03, -4.58493829e-03,\n", + " -6.07848167e-03, 1.07879266e-02, -7.88053870e-03,\n", + " 5.81801683e-03, -5.74702024e-03, -1.39639974e-02,\n", + " -6.83215261e-03, 4.36457917e-02, 3.14892232e-02,\n", + " 7.88080785e-03, -8.65940750e-03, 2.13787258e-02,\n", + " -4.84359264e-03, 3.76113243e-02, -7.68777728e-03,\n", + " 3.44041884e-02, 2.17271540e-02, 3.49786282e-02,\n", + " -9.57245380e-03, 1.17718354e-02, 1.76328421e-03,\n", + " 2.47428566e-02, -1.89929008e-02, 9.83476639e-07,\n", + " -4.87717986e-03, 1.81090236e-02, -1.68472528e-03,\n", + " 1.75802708e-02, 1.03006195e-02, 2.29674578e-03,\n", + " 0.00000000e+00, 1.71143487e-02, -5.93899190e-03,\n", + " -8.52180272e-03, 2.21078098e-03, -6.65974617e-03,\n", + " -5.27688861e-03, 1.78252906e-03, 1.92168504e-02,\n", + " 3.20638865e-02, 3.12381685e-02, -9.20593739e-04,\n", + " 2.36202478e-02, 1.44775957e-02, -2.05530357e-02,\n", + " 1.06063485e-03, 3.32839042e-03, -1.01231039e-03,\n", + " -2.06442177e-03, 9.97217000e-03, 2.65306234e-03,\n", + " 1.02972686e-02, -7.31389225e-03, 8.04215856e-03,\n", + " 9.98052955e-03, -1.05124712e-03, -6.54879957e-03,\n", + " 2.02297121e-02, 1.26888603e-02, 2.11532712e-02,\n", + " 8.01965594e-03, -1.44157931e-02, -1.14605129e-02,\n", + " 5.96144795e-03],\n", + " [ 0.00000000e+00, -1.39323249e-02, 6.39429688e-03,\n", + " -6.28607944e-02, -8.64857435e-03, 1.24468859e-02,\n", + " 2.86395848e-03, 8.28726962e-03, 1.22120306e-02,\n", + " 3.18306237e-02, -6.85551018e-03, 2.63845623e-02,\n", + " 2.07465943e-02, 0.00000000e+00, 2.27561444e-02,\n", + " 1.00681484e-02, 2.27013733e-02, -6.71585724e-02,\n", + " -7.02583045e-03, -1.04030967e-02, -4.20713425e-03,\n", + " 1.82433575e-02, -1.60830319e-02, 1.62253212e-02,\n", + " -1.33093297e-02, -3.11521888e-02, 3.77973914e-03,\n", + " 5.21981716e-03, -1.98365897e-02, 0.00000000e+00,\n", + " 2.18344629e-02, -2.21273601e-02, 3.62999588e-02,\n", + " 8.26938823e-03, -7.49810040e-03, -1.21858753e-02,\n", + " 0.00000000e+00, -1.79555565e-02, 6.84709847e-03,\n", + " -5.80060855e-02, 5.37337959e-02, -3.77775021e-02,\n", + " -7.55639747e-04, -2.18640100e-02, -4.38468456e-02,\n", + " 1.06569529e-02, -2.04435959e-02, 1.75144523e-02,\n", + " -3.92878056e-03, -3.37126106e-03, 9.41401571e-02,\n", + " 2.34664902e-02, 3.57846729e-02, -1.89120620e-02,\n", + " -3.53061408e-02, -1.61360130e-02, -7.82630816e-02,\n", + " 2.25571841e-02, -3.70850414e-03, 0.00000000e+00,\n", + " 1.72483921e-03, -3.43789682e-02, -6.79399073e-03,\n", + " 8.89310054e-03, -6.08799458e-02, -2.71661580e-03,\n", + " -3.56471390e-02, 1.11350343e-02, -2.56258100e-02,\n", + " 2.05378458e-02],\n", + " [ 2.23517418e-06, -7.50984997e-02, 1.20643079e-02,\n", + " 2.09403038e-02, 1.78195983e-02, 1.29758865e-02,\n", + " 9.08046961e-04, -7.36726820e-03, 8.75070691e-03,\n", + " -3.93393636e-03, 4.93140984e-03, -7.30901957e-03,\n", + " -2.97645107e-03, 4.18809950e-02, 1.70973763e-02,\n", + " 5.04964590e-03, -1.19856521e-02, -5.31809032e-03,\n", + " 9.80186462e-03, 6.25750273e-02, 7.91132450e-04,\n", + " 5.25092632e-02, 1.74276680e-02, 2.94222236e-02,\n", + " -9.10501927e-03, 1.67332776e-03, 2.11818963e-02,\n", + " 1.82696581e-02, -2.05858052e-03, 1.21816993e-06,\n", + " -2.51303613e-02, -7.00031966e-03, -1.90818682e-02,\n", + " 2.38134339e-02, 1.85852498e-02, 1.88816562e-02,\n", + " 0.00000000e+00, 2.37748027e-03, 1.05057806e-02,\n", + " -5.73440939e-02, -5.06321341e-03, 1.27147976e-02,\n", + " -9.73396003e-04, 9.90636647e-03, 2.64340937e-02,\n", + " 1.75562948e-02, 1.66450143e-02, 1.51741505e-03,\n", + " 2.09401064e-02, 2.11877432e-02, -1.54795647e-02,\n", + " 1.00296140e-02, -3.15607935e-02, -2.11823583e-02,\n", + " 5.22204116e-02, 6.60809129e-02, -2.77949665e-02,\n", + " 1.30063444e-02, 1.25313178e-03, 2.72946656e-02,\n", + " 4.89727184e-02, -1.69895999e-02, 1.30629092e-02,\n", + " 1.55981630e-03, 3.25212553e-02, 2.51948833e-04,\n", + " -5.40875457e-03, -1.81352496e-02, -5.66543639e-03,\n", + " 2.94677094e-02],\n", + " [ 0.00000000e+00, -4.97145504e-02, -3.21335346e-03,\n", + " 2.53131799e-02, 4.13774513e-03, -1.40788406e-03,\n", + " -5.41490316e-03, -3.77743170e-02, 5.08022308e-03,\n", + " 1.75616145e-03, 1.51727945e-02, 1.46141499e-02,\n", + " 4.79898602e-02, 5.34517616e-02, 1.71546526e-02,\n", + " 3.00733373e-03, 2.73067039e-02, 4.28619981e-03,\n", + " 1.95392966e-03, -1.62802637e-03, -1.72183998e-02,\n", + " -9.32161510e-03, 8.53136927e-03, -1.32551491e-02,\n", + " -2.45124102e-04, 2.85785496e-02, -4.13239896e-02,\n", + " 9.71779227e-03, -4.31602448e-03, 0.00000000e+00,\n", + " 2.40202248e-03, -5.27902842e-02, -3.45261432e-02,\n", + " -1.48351621e-02, 1.39084011e-02, 1.61831826e-02,\n", + " 0.00000000e+00, -5.43668903e-02, 7.43016601e-03,\n", + " 1.10875398e-01, -9.89424437e-03, 1.02454871e-02,\n", + " -2.66520679e-03, -3.40092704e-02, -2.39770263e-02,\n", + " -4.27880883e-02, -1.02666572e-01, -4.84505296e-03,\n", + " -3.08043435e-02, 2.15946883e-02, 7.27008581e-02,\n", + " -1.56747103e-02, 3.66818905e-03, 2.17209607e-02,\n", + " -1.88962929e-02, -1.54630542e-02, 3.92973498e-02,\n", + " 3.84222344e-02, 2.19388902e-02, 1.04343221e-02,\n", + " 3.98382545e-03, 2.20829248e-03, 1.71164274e-02,\n", + " 6.02347218e-03, -4.65963036e-03, 5.45533001e-03,\n", + " 1.20884478e-02, 1.58099681e-02, 3.72115970e-02,\n", + " 4.13580202e-02],\n", + " [ 5.59026375e-07, 2.19073892e-02, 1.40187293e-02,\n", + " 4.17544991e-02, 7.14531541e-03, 1.93578005e-03,\n", + " 3.68960202e-04, 9.27004218e-03, -7.46112317e-04,\n", + " 7.42256641e-03, 6.71111792e-03, 1.33904815e-02,\n", + " -7.87302852e-03, 1.38419718e-02, 2.20113285e-02,\n", + " 2.13821232e-02, -3.75646353e-03, 2.36315280e-02,\n", + " 6.75320625e-05, 2.17337422e-02, 3.00095975e-03,\n", + " 2.45792195e-02, 1.44614577e-02, 2.65318509e-02,\n", + " -1.13605559e-02, 8.07042047e-03, 3.87509167e-03,\n", + " 2.12808922e-02, 6.70775771e-04, 7.59959221e-07,\n", + " -5.00490516e-03, 2.46562809e-03, -7.56940991e-03,\n", + " 1.34187639e-02, 4.02834602e-02, 1.11031756e-02,\n", + " 0.00000000e+00, -2.19111145e-03, 2.61621177e-03,\n", + " 1.26020499e-02, -5.91659546e-03, 4.12940979e-03,\n", + " 7.18377531e-04, 1.28328502e-02, 9.83218849e-03,\n", + " 1.39725059e-02, 7.12452829e-03, 9.86279920e-03,\n", + " 1.29076242e-02, 1.15619823e-02, -2.26510987e-02,\n", + " -1.01644099e-02, 6.79536909e-03, 3.28451395e-03,\n", + " -2.67382041e-02, 3.20950896e-03, 2.13438272e-03,\n", + " 2.84881890e-03, -1.82493776e-03, 2.74327397e-03,\n", + " 5.69045544e-04, 1.74677297e-02, 2.59828940e-03,\n", + " 9.38987918e-03, 2.89438292e-02, -1.68662779e-02,\n", + " 2.16373429e-03, 5.52284718e-02, 1.71898305e-03,\n", + " 1.84112042e-02],\n", + " [ 1.92970037e-06, -7.33059645e-03, 6.08241558e-03,\n", + " 3.09040304e-02, -2.66182497e-02, -1.04110688e-03,\n", + " 6.55319542e-03, 1.30551010e-02, -6.43594936e-03,\n", + " 3.14751267e-03, -4.20112303e-03, -3.34899174e-03,\n", + " 1.19088590e-03, 2.75004655e-02, 1.39071923e-02,\n", + " -3.07884067e-02, -6.91840798e-03, 3.75329219e-02,\n", + " -9.75349918e-03, 1.89966261e-02, 1.39560252e-02,\n", + " 1.82078369e-02, 2.44169421e-02, 3.45724225e-02,\n", + " 3.83076072e-03, 2.96819210e-03, -2.79713571e-02,\n", + " 1.44217312e-02, -1.04773343e-02, 2.71573663e-06,\n", + " -6.85906410e-03, -2.69274563e-02, -1.63219962e-02,\n", + " 2.06331909e-02, 1.40391737e-02, 9.01226699e-03,\n", + " 0.00000000e+00, -8.03788006e-03, 2.00668350e-02,\n", + " 1.17101632e-02, 8.09775665e-03, 1.32986456e-02,\n", + " 5.69504499e-03, 3.66747677e-02, -3.86754144e-03,\n", + " 2.89282054e-02, 3.69293764e-02, -2.47304142e-03,\n", + " 1.71321183e-02, -7.54578412e-03, -9.61725414e-03,\n", + " 1.61448121e-03, -1.06247291e-02, 3.18420455e-02,\n", + " 3.38531211e-02, 3.97837535e-02, -5.06854057e-03,\n", + " 4.11833264e-03, -2.18693763e-02, 0.00000000e+00,\n", + " 4.73356694e-02, 2.79366970e-02, 1.81355327e-03,\n", + " 1.18001401e-02, 1.80805475e-02, -4.14552540e-03,\n", + " -4.44880128e-03, -3.96236032e-03, -2.36448124e-02,\n", + " 1.94880292e-02],\n", + " [ 6.82026148e-05, 7.18571916e-02, 4.76970673e-02,\n", + " 9.57600772e-02, 1.31565988e-01, -1.64669454e-01,\n", + " -7.16286302e-02, 1.19898200e-01, -2.11336374e-01,\n", + " 1.21684730e-01, 1.44014955e-01, -2.38232017e-01,\n", + " 1.46218359e-01, 4.44871187e-03, 7.22973943e-02,\n", + " 2.78884321e-02, 1.45830631e-01, -1.08749241e-01,\n", + " -3.77449691e-02, -1.96678042e-02, -1.76536024e-01,\n", + " 5.42484522e-02, 6.87512457e-02, 2.47801125e-01,\n", + " 1.44790769e-01, 1.97795749e-01, 4.98535782e-02,\n", + " 2.33573496e-01, 1.99592710e-02, 2.87592411e-06,\n", + " -1.66886806e-01, -1.45491302e-01, 2.27394998e-01,\n", + " 8.62296224e-02, -2.82013416e-03, -1.88887715e-01,\n", + " 0.00000000e+00, 6.12580180e-02, 1.33518666e-01,\n", + " -1.46177143e-01, 1.59401894e-02, -2.11909533e-01,\n", + " -1.65744841e-01, 1.35906681e-01, -2.10718960e-02,\n", + " 1.91993415e-01, 2.53243625e-01, -1.76544130e-01,\n", + " 4.48338389e-02, 1.66558981e-01, 4.43497002e-01,\n", + " 9.72419083e-02, -1.00000083e-01, 3.86097550e-01,\n", + " -1.01794243e-01, -5.31688072e-02, -7.42531270e-02,\n", + " 2.60491490e-01, -1.89655364e-01, -1.26332268e-02,\n", + " 5.77271283e-02, 1.28030062e-01, -1.77092314e-01,\n", + " 2.60912895e-01, -1.63142890e-01, -2.03584060e-02,\n", + " 1.95915878e-01, -1.84553564e-02, 1.37363464e-01,\n", + " -1.53825760e-01],\n", + " [ 0.00000000e+00, 5.06502837e-02, 1.29767060e-02,\n", + " 6.28483295e-02, 6.34282753e-02, -1.72892511e-02,\n", + " -9.57809016e-03, -1.54778808e-02, 9.92524205e-04,\n", + " -7.11901113e-03, -1.05080698e-02, 1.87457874e-02,\n", + " 3.44372615e-02, -4.49787974e-02, 6.05929643e-03,\n", + " -8.81178081e-02, -3.13648731e-02, -3.27146053e-03,\n", + " 1.48571506e-02, 9.23416018e-03, -1.23183951e-02,\n", + " 4.48894389e-02, 5.51292300e-02, -4.26419824e-02,\n", + " 2.51580812e-02, 5.67226112e-03, -1.00748777e-01,\n", + " -2.07852945e-02, -2.21131444e-02, 0.00000000e+00,\n", + " -2.12397873e-02, 1.61559135e-03, 1.88499428e-02,\n", + " 8.83308202e-02, 0.00000000e+00, -2.65170336e-02,\n", + " 0.00000000e+00, -3.47553790e-02, 3.28133181e-02,\n", + " -1.98322237e-02, -5.88746481e-02, 3.99494730e-02,\n", + " -2.77045369e-03, -2.44742259e-04, -3.06502879e-02,\n", + " -2.40243226e-03, 4.12266403e-02, -4.80247736e-02,\n", + " 3.03203985e-02, -2.57963911e-02, -3.67290676e-02,\n", + " 6.71202391e-02, -8.20160732e-02, 2.72179991e-02,\n", + " 2.75232196e-02, 2.28866339e-02, -4.97791916e-03,\n", + " 3.20631042e-02, 2.17096210e-02, 0.00000000e+00,\n", + " -4.09361497e-02, 2.31957436e-02, -2.09520608e-02,\n", + " -1.32357776e-02, 5.95888197e-02, 0.00000000e+00,\n", + " -8.28909650e-02, 4.27012593e-02, 2.20479891e-02,\n", + " 2.66924053e-02],\n", + " [ 2.80141830e-06, 3.44573855e-02, 1.68616474e-02,\n", + " -2.16292441e-02, 3.42520177e-02, -2.51318514e-03,\n", + " -2.99223959e-02, 3.11351269e-02, -7.86145031e-03,\n", + " 3.79217640e-02, 5.88734746e-02, -2.47791409e-03,\n", + " -1.02755167e-02, 1.48909986e-02, 6.38209134e-02,\n", + " -2.27537304e-02, -2.68034339e-02, 8.43681395e-03,\n", + " 5.44354320e-04, 2.46977806e-03, -6.86615705e-03,\n", + " 5.15806824e-02, 3.51931527e-02, 4.50379923e-02,\n", + " 1.82978213e-02, 1.10631064e-02, 5.33567518e-02,\n", + " 5.02201617e-02, -2.68985629e-02, 2.14576721e-06,\n", + " 3.81427701e-03, 2.03454792e-02, -1.38287246e-02,\n", + " -2.03619152e-02, 3.11780423e-02, 1.26634762e-02,\n", + " 0.00000000e+00, 2.77237929e-02, 9.62861814e-03,\n", + " -1.84376612e-02, 9.24696773e-03, -5.45406435e-03,\n", + " -1.83334649e-02, 7.35372305e-04, -1.68629438e-02,\n", + " 8.47994536e-03, -1.79683268e-02, 1.69715211e-02,\n", + " 1.91382784e-02, 4.71077561e-02, -4.38696444e-02,\n", + " 5.99594712e-02, 7.32099563e-02, -4.90160286e-02,\n", + " 2.35909577e-02, 1.54711753e-02, 1.36633702e-02,\n", + " -4.24712151e-02, -2.49780267e-02, 7.48580322e-03,\n", + " 6.65621459e-03, -2.29308978e-02, -2.68799514e-02,\n", + " 1.97568089e-02, 6.91704005e-02, 2.21450403e-02,\n", + " -2.95889378e-03, -7.42284954e-03, 3.10672224e-02,\n", + " 7.04184175e-04],\n", + " [ 9.68575478e-08, 3.99114043e-02, -9.91953909e-03,\n", + " 3.25712711e-02, -1.46806389e-02, 5.75292110e-03,\n", + " 4.84618917e-03, 5.86069934e-03, -4.93802130e-04,\n", + " 1.27205700e-02, -1.29961967e-03, 6.42070174e-03,\n", + " -1.86905414e-02, 3.09723020e-02, 1.03429109e-02,\n", + " 5.13049960e-03, -5.01155853e-04, 3.36712301e-02,\n", + " -3.53804231e-03, 3.75385135e-02, 8.72188807e-03,\n", + " 1.75718218e-02, 3.87762487e-03, 3.15307751e-02,\n", + " -1.69710331e-02, 7.96510279e-03, 1.68320760e-02,\n", + " 2.14911848e-02, 3.29448655e-03, 4.91738319e-07,\n", + " 9.02945548e-03, 1.97583884e-02, -3.59129906e-03,\n", + " 3.39994431e-02, 1.95574164e-02, 1.80864036e-02,\n", + " 0.00000000e+00, 1.42792165e-02, 4.98805940e-03,\n", + " 3.40463128e-03, 3.10290605e-03, -1.59697235e-03,\n", + " 3.81151587e-03, -1.06487721e-02, 2.39797831e-02,\n", + " 1.12190805e-02, 1.06314570e-02, 1.27038062e-02,\n", + " 5.77275455e-03, 4.94579226e-03, -7.09898770e-03,\n", + " 1.74231827e-03, 1.20699406e-06, -1.24374777e-03,\n", + " 6.46207333e-02, 2.14003474e-02, 1.49126686e-02,\n", + " 6.48828223e-03, -5.38637117e-03, 1.02209002e-02,\n", + " 7.54408538e-03, 7.72580504e-03, 3.43549997e-03,\n", + " 7.72102922e-03, 2.00577080e-02, 4.22081202e-02,\n", + " -5.91963530e-04, 1.37665868e-02, -9.87207890e-03,\n", + " 2.15003788e-02]], dtype=float32),\n", + " array([[-2.91615725e-05, 0.00000000e+00, 0.00000000e+00, ...,\n", + " 1.14738941e-05, 0.00000000e+00, 0.00000000e+00],\n", + " [ 8.49024765e-03, -4.31102738e-02, -5.05789816e-02, ...,\n", + " 1.43139213e-02, -3.98727357e-02, -8.07711482e-03],\n", + " [ 1.00968346e-01, 7.57245272e-02, -2.73677707e-03, ...,\n", + " -3.34357619e-02, -2.99723148e-02, 3.60149145e-03],\n", + " ...,\n", + " [ 2.19667554e-02, 1.55268461e-02, 1.16532966e-02, ...,\n", + " -6.69102073e-02, -4.36145514e-02, 8.49020481e-03],\n", + " [ 4.64867577e-02, 2.68688127e-02, -7.21535087e-03, ...,\n", + " -1.35802627e-02, 1.19180009e-02, 1.85405612e-02],\n", + " [-4.64454293e-04, 2.76498124e-02, -5.06965816e-03, ...,\n", + " 2.36923993e-03, 1.80526227e-02, 2.38161311e-02]], dtype=float32),\n", + " array([[-0.03158482, 0.03222603, 0.02458179, ..., 0.08041045,\n", + " 0.01049814, -0.04392172],\n", + " [ 0.01048962, -0.01012598, -0.01415333, ..., -0.13836852,\n", + " 0.00485021, -0.01224236],\n", + " [-0.00740135, -0.00507368, 0.00139559, ..., -0.13694564,\n", + " -0.00293422, 0.00983405],\n", + " ...,\n", + " [ 0.00256901, -0.00454452, 0.00040635, ..., 0.00017413,\n", + " 0.0033875 , 0.03571756],\n", + " [-0.01601908, 0.00801374, -0.00454768, ..., -0.10811353,\n", + " 0.01376003, -0.00697925],\n", + " [-0.0469159 , 0.01431713, 0.0028784 , ..., 0.02210182,\n", + " -0.01271321, -0.04000308]], dtype=float32),\n", + " array([[-1.15624368e-02, 1.15615055e-02],\n", + " [ 7.26526976e-03, -7.26565719e-03],\n", + " [ 2.03531086e-02, -2.03537643e-02],\n", + " [-6.14860654e-03, 6.14863634e-03],\n", + " [-1.56270713e-02, 1.56272203e-02],\n", + " [-1.02237239e-03, 1.02218986e-03],\n", + " [ 1.03584975e-02, -1.03584863e-02],\n", + " [ 2.06144452e-02, -2.06139088e-02],\n", + " [-1.49596017e-04, 1.48773193e-04],\n", + " [ 4.79128957e-03, -4.79125977e-03],\n", + " [ 6.30669296e-03, -6.30635023e-03],\n", + " [ 0.00000000e+00, 0.00000000e+00],\n", + " [-3.16215307e-03, 3.16265225e-03],\n", + " [ 5.85643947e-02, -5.85644245e-02],\n", + " [-7.11206347e-04, 7.11120665e-04],\n", + " [-1.32675469e-02, 1.32675469e-02],\n", + " [-4.61159647e-03, 4.61125374e-03],\n", + " [ 1.90031528e-03, -1.90034509e-03],\n", + " [ 1.55715942e-02, -1.55715793e-02],\n", + " [-1.19959861e-02, 1.19959936e-02],\n", + " [-1.39654800e-02, 1.39658749e-02],\n", + " [-5.74159622e-03, 5.74158132e-03],\n", + " [-2.31015980e-02, 2.31016129e-02],\n", + " [-3.14235687e-04, 3.14325094e-04],\n", + " [ 2.00573653e-02, -2.00573802e-02],\n", + " [ 0.00000000e+00, 0.00000000e+00],\n", + " [-1.67117715e-02, 1.67115927e-02],\n", + " [-9.77398362e-03, 9.77447629e-03],\n", + " [ 0.00000000e+00, 0.00000000e+00],\n", + " [-1.44827366e-02, 1.44827068e-02],\n", + " [-4.94149094e-03, 4.94125485e-03],\n", + " [-7.87574053e-03, 7.87599012e-03],\n", + " [-1.95778906e-03, 1.95809081e-03],\n", + " [ 5.55017591e-02, -5.55018783e-02],\n", + " [ 1.07816160e-02, -1.07815862e-02],\n", + " [-3.71644944e-02, 3.71643975e-02],\n", + " [-2.17199326e-04, 2.16580927e-04],\n", + " [ 5.87260723e-03, -5.87230921e-03],\n", + " [-2.86758766e-02, 2.86761075e-02],\n", + " [-2.14453358e-02, 2.14453079e-02],\n", + " [-1.89587176e-02, 1.89587921e-02],\n", + " [-2.95607746e-03, 2.95570493e-03],\n", + " [ 1.51866078e-02, -1.51871741e-02],\n", + " [-1.82390213e-05, 1.85519457e-05],\n", + " [-3.47673893e-04, 3.47673893e-04],\n", + " [ 2.12892145e-02, -2.12893784e-02],\n", + " [-3.14503908e-04, 3.14474106e-04],\n", + " [-2.15784311e-02, 2.15792507e-02],\n", + " [-3.74348462e-03, 3.74361873e-03],\n", + " [-4.07821834e-02, 4.07828614e-02]], dtype=float32)],\n", + " [array([ 1.93715096e-06, -7.33081996e-03, -2.59644464e-02, 3.53913940e-02,\n", + " -3.17766666e-02, 2.61158086e-02, 3.09046637e-02, -1.39724910e-02,\n", + " 1.64157562e-02, -1.33834183e-02, -2.69608200e-02, 1.18504055e-02,\n", + " -3.39354798e-02, 2.74997652e-02, 8.21013749e-03, 3.14356387e-03,\n", + " -1.39056407e-02, 3.89015339e-02, -4.97761788e-03, 4.34829444e-02,\n", + " 1.99493878e-02, 1.66720618e-02, -8.93015787e-03, 1.38194822e-02,\n", + " -3.19721177e-02, -1.15122199e-02, 3.29295881e-02, -1.12637132e-03,\n", + " 2.69653797e-02, 2.71573663e-06, 2.13105604e-02, 8.05605203e-03,\n", + " -1.68812610e-02, 6.44891895e-03, 1.83254965e-02, 3.12887095e-02,\n", + " 0.00000000e+00, -1.13615617e-02, -1.91472992e-02, 2.73335315e-02,\n", + " -8.05212930e-03, 1.09619126e-02, 2.75111161e-02, 7.87547231e-03,\n", + " 5.21017611e-02, 8.26306641e-04, 4.48312610e-03, 1.90423653e-02,\n", + " -9.32561979e-03, -1.31163374e-02, -7.42557272e-03, -8.75856541e-03,\n", + " -3.20619484e-03, 1.95884481e-02, -3.87861580e-02, 7.03234226e-04,\n", + " -4.51306719e-03, 1.04681477e-02, 9.19631869e-03, 4.49468940e-03,\n", + " 8.59957188e-04, -1.29299201e-02, 2.50575244e-02, -6.41628355e-03,\n", + " 1.74099207e-02, 6.38409331e-03, -1.43194981e-02, -5.27505018e-03,\n", + " -2.04372257e-02, 3.64710018e-02], dtype=float32),\n", + " array([ 0.00331618, 0.0539567 , -0.03811157, -0.01109656, 0.02418597,\n", + " -0.02698364, -0.00161759, -0.02607356, 0.00603188, 0.0166943 ,\n", + " 0.00161197, 0.01395666, -0.02311181, -0.02586061, 0.01232279,\n", + " 0. , 0.00035155, 0.01375218, 0.01308896, -0.01829513,\n", + " -0.01323594, 0.01266809, -0.01157204, -0.01712865, 0.00102041,\n", + " -0.00915188, 0.00994506, 0.0003348 , 0.01999653, -0.01448206,\n", + " 0.02454722, -0.00036955, 0.01865685, -0.00514613, 0.01885039,\n", + " -0.02837751, 0.00375455, -0.04659929, 0.00275468, 0.00854283,\n", + " 0.02362405, -0.02947872, 0.01506701, -0.00749737, -0.02634757,\n", + " 0.02495424, -0.02921787, 0.01843503, -0.00016499, 0.01456888],\n", + " dtype=float32),\n", + " array([ 0.00049266, -0.02066331, -0.0207266 , -0.01927541, 0.02330254,\n", + " 0.00480025, -0.0293951 , 0.01040582, 0.00090722, -0.0259526 ,\n", + " -0.02241635, 0. , -0.02288882, 0.02825862, 0.00242692,\n", + " 0.02985683, 0.00974142, 0.00933962, -0.02005748, 0.01471977,\n", + " 0.00262542, 0.00572324, 0.02363591, 0.00187776, 0.01032351,\n", + " 0. , -0.01178178, -0.01856739, 0. , 0.01882949,\n", + " -0.00219416, 0.01156303, 0.00352357, 0.01064696, 0.02187355,\n", + " 0.00582501, -0.02582753, -0.02192662, -0.00337636, 0.02204296,\n", + " 0.02771452, 0.0342248 , 0.00142591, 0.0136618 , 0.00158552,\n", + " 0.01320035, 0.0052134 , -0.03961967, -0.01420508, 0.00930639],\n", + " dtype=float32),\n", + " array([-0.01254899, 0.01254895], dtype=float32)])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_updates(model, X_train, y_train, 32, 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "W = get_parameters(model)[0]\n", + "B = get_parameters(model)[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# BASELINE SCENARIO\n", + "#buid the model as base line for the shards (sequential)\n", + "# Number of peers\n", + "#accordin to what we need\n", + "ss = int(len(X_train)/n_peers)\n", + "inputs_in = X_train[0*ss:0*ss+ss]\n", + "outputs_in = y_train[0*ss:0*ss+ss]\n", + "def build_model(X_t, y_t):\n", + " model = Sequential()\n", + " model.add(Dense(70, input_dim=Features_number, activation='relu'))\n", + " model.add(Dense(50, activation='relu'))\n", + " model.add(Dense(50, activation='relu'))\n", + " model.add(Dense(2, activation='softmax'))\n", + " model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])\n", + " model.fit(X_t,\n", + " y_t, \n", + " batch_size=32, \n", + " epochs=250, \n", + " verbose=1,\n", + " validation_data=((X_test, y_test)))\n", + " return model\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"sequential_1\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "dense_1 (Dense) (None, 70) 910 \n", + "_________________________________________________________________\n", + "dense_2 (Dense) (None, 50) 3550 \n", + "_________________________________________________________________\n", + "dense_3 (Dense) (None, 50) 2550 \n", + "_________________________________________________________________\n", + "dense_4 (Dense) (None, 2) 102 \n", + "=================================================================\n", + "Total params: 7,112\n", + "Trainable params: 7,112\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + }, + { + "data": { + "text/plain": [ + "None" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display(model.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# predict probabilities for test set\n", + "yhat_probs = model.predict(X_test, verbose=0)\n", + "# predict crisp classes for test set\n", + "yhat_classes = model.predict_classes(X_test, verbose=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: 0.836071\n", + "Precision: 0.791075\n", + "Recall: 0.483871\n", + "F1 score: 0.600462\n" + ] + } + ], + "source": [ + "# accuracy: (tp + tn) / (p + n)\n", + "accuracy = accuracy_score(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1))\n", + "print('Accuracy: %f' % accuracy)\n", + "# precision tp / (tp + fp)\n", + "precision = precision_score(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1))\n", + "print('Precision: %f' % precision)\n", + "# recall: tp / (tp + fn)\n", + "recall = recall_score(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1))\n", + "print('Recall: %f' % recall)\n", + "# f1: 2 tp / (2 tp + fp + fn)\n", + "f1 = f1_score(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1))\n", + "print('F1 score: %f' % f1)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[2257, 103],\n", + " [ 416, 390]], dtype=int64)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQUAAADwCAYAAAAJvnGPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAATe0lEQVR4nO3dfYxc1XnH8e8PQ0hIgsDhpcSY2kQGFVDixBZYQkGJaLChEYZKpHYrcADJgIwEUv4AEiQoFCltSlBQE5ApFiBRiFsCWMgEXDcqSYUBGyjEvNpAwLAyMkaAajDs7tM/7pndO/a83B2f8ezO/j7S1c49c+feOyv72fNyz3kUEZiZ1ezT6xsws/HFQcHM6jgomFkdBwUzq+OgYGZ1HBTMrI6Dgtk4I2m6pN9JelHSRkmXpfKfSXpJ0nOS7pd0UCqfIeljSc+m7dbSueZIel7SJkk3S1Lb6/s5BbPxRdIRwBER8bSkLwMbgLOAI4H/iohBSf8IEBFXSJoBPBQRJzQ415PAZcA6YDVwc0Q83Or6rimYjTMRMRART6fXHwEvAtMi4tGIGEyHraMIEk2l4HJgRDwexV//uyiCS0v77tHdmxkA87/7xXhv+1ClYzc8t3Mj8EmpaHlELG90bKoFfBN4Ype3LgB+XdqfKekZ4EPg6oj4PTAN2FI6Zksqa8lBwSyDbduHeOKRln+4R+x3xOZPImJuu+MkfQm4D7g8Ij4slf8EGATuTkUDwFER8Z6kOcADko4HGvUftO0vcFAwyyIYiuFsZ5O0H0VAuDsiflMqXwJ8Hzg1NQmIiJ3AzvR6g6TNwDEUNYNypDoSeKfdtd2nYJZBAMNEpa2dNEJwO/BiRPy8VL4AuAI4MyJ2lMoPlTQlvT4amAW8FhEDwEeS5qVzngc82O76rimYZRAEn0W1PoUKTgbOBZ6X9Gwq+zFwM7A/sCaNLK6LiIuBU4DrJA0CQ8DFEbE9fe4S4A7gC8DDaWvJQcEskyq1gCoi4g807g9Y3eT4+yiaGo3eWw/sNlTZyqRvPkhaIOnl9HDHlb2+n34jaYWkdyX9sdf30k0BDBGVtvFuUgeF1A77JXA6cBywWNJxvb2rvnMHsKDXN7E35OpT6LVJHRSAE4FNEfFaRHwK3Ass7PE99ZWIeAzY3vbACS6AoYhK23g32YPCNOCt0n6lhzvMGhmuuI13k72jsaOHO8x2FROkv6CKyR4UtgDTS/uVHu4w21UEfNYfMWHSB4WngFmSZgJvA4uAv+3tLdnEJIYaVjwnnkndp5BmnF0KPEIxE21lRGzs7V31F0n3AI8Dx0raIunCXt9TNwQwHNW28W6y1xSIiNU0eSjE9lxELO71Pewt/VJTmPRBwSyH4uElBwUzKxkOBwUzS1xTMLM6gfgspvT6NrKY1KMPNZKW9voe+l2//45rNYUq23jnoFDo63+w40Sf/47FUOxTaRvv3Hwwy6BYeWn8/4evoitB4ZCpU2LG9P26cequOGravsz9xucnwGMlo1557oBe38KYfJ4DOFBTJ9Tv+BP+j09jZ+X6/kRoGlTRlaAwY/p+PPnI9PYHWsfmf3V2r2+h7z0RaysfG6FsTQNJ0ylyNPwZxcTK5RHxC0lTKZZ1nwG8AfwgIt5P6y/+AjgD2AH8sJY3Ii30enU69T9ExJ3trt8f9R2zcWAYVdoqGAR+FBF/AcwDlqXFf64E1kbELGBt2odikaBZaVsK3AKQgsg1wEkUa4dcI+ngdhd3UDDLIBCfxr6VtrbnapIhimIBoNpf+jsZzfa0ELgrCuuAg1J2qPnAmojYHhHvA2uosAqWOxrNMhhjR+MhktaX9qtmiDo8LdtORAxIOiwd1myxoI4WEXJQMMtkqPpjzts6yRDVImF0s8WCOlpEyM0HswwCMcQ+lbYqmmSI2pqaBbXkse+m8maLBXW0iJCDglkmw7FPpa2dZhmigFXAkvR6CaPZnlYB56kwD/ggNTMeAU6TdHDqYDwtlbXk5oNZBsVjztn+xjbLEPVTYGVaqOZN4Jz03mqK4chNFEOS5wNExHZJ11OsMAZwXSlzVFMOCmYZ5JwQ1SJDFMCpDY4PYFmTc60AVozl+g4KZhlEMCHmNVThoGCWReUHk8Y9BwWzDIoMUa4pmFlJxo7GnnJQMMsgkNdoNLN6rimY2Yh+WqPRQcEsgyJDlGsKZlbilZfMbESEXFMws3p+TsHMRhSLrLj5YGYj8i3c2msOCmYZBHhI0sxG+YlGM9uNM0SZ2YhiPYX+qCn0R2gzGweGQ5W2KiStkPSupD+Wyn4t6dm0vVFbqk3SDEkfl967tfSZOZKel7RJ0s1qsSR0jWsKZhkUfQpZ/8beAfwLRfq44hoRf1N7LelG4IPS8ZsjolEuwVsoskato1jLcQHwcKsLu6ZglskQqrRVERGPAQ0XWU1/7X8A3NPqHGkZ+AMj4vG0juNdjGaVaso1BbMMAjE4XHlIsnKGqCa+DWyNiFdLZTMlPQN8CFwdEb+nyAa1pXSMM0SZ7U1jeKKxUoaoFhZTX0sYAI6KiPckzQEekHQ8HWaIclAwy2BvjT5I2hf4a2DO6LVjJ7Azvd4gaTNwDEXN4MjSx50hymxvypUhqo2/BF6KiJFmgaRDJU1Jr4+mSEn/WsoS9ZGkeakf4jxGs0o15aBglkHticaMQ5L3AI8Dx0rakrJCASxi9w7GU4DnJP0v8B/AxaVMUJcA/0qRPWozbUYewM0Hs2xyzpKMiMVNyn/YoOw+imS0jY5fD5wwlms7KJhlUCzH1h9PNDoomOUQYxqSHNccFMwy8CIrZrYbNx/MbEQ/9SlUGpKUtEDSy2mm1ZXdvimziSjnkGQvta0ppIcifgl8j+IJqackrYqIF7p9c2YTxWRbeelEYFNEvAYg6V5gIeCgYFYTMDiJFm6dBrxV2t8CnNSd2zGbmPqpT6FKUKg000rSUorFHDhqmvsvbfLpl6BQpb6zBZhe2m840yoilkfE3IiYe+hX+uMhDrOqcs996KUqQeEpYJakmZI+RzEhY1V3b8ts4olQpW28a1vPj4hBSZcCjwBTgBURsbHrd2Y2wUyqJxojYjXFoo9m1kBE//QpuEfQLAsxNDx5hiTNrIKJ0F9QhYOCWQb99JxCf9R3zHotin6FKlsVTTJEXSvp7VImqDNK712V5ia9LGl+qXzM85YcFMwyGUaVtoruoMjmtKubImJ22lYDSDqO4lGB49NnfiVpSmne0unAccDidGxLbj6YZRDk7VOIiMckzah4+ELg3rTU++uSNlHMWYIO5i25pmCWxZieaDxE0vrStnQMF7pU0nOpeXFwKms0P2lai/KWXFMwy2R4uOsZom4BrqeomFwP3AhcQPP5SY3+6DtDlNneUHQidnf0ISK21l5Lug14KO22mp/Udt7Srtx8MMuk2xOiUhbpmrOB2sjEKmCRpP0lzaTIEPUkHc5bck3BLJOqw41VpAxR36Hof9gCXAN8R9JsiibAG8BFxXVjo6SVFB2Ig8CyiBhK5xnzvCUHBbNMMo8+NMoQdXuL428AbmhQPuZ5Sw4KZhkEE2NadBUOCmaZZGw99JSDglkOAVF9SHJcc1Awy8TNBzOrk3P0oZccFMwyyD33oZccFMxyCMBBwczK3Hwws3oOCmY2Sh6SNLOSvTBLcm9xUDDLxc0HM6vnmoKZlbmmYGZ1HBTMbIQnRJnZbvqkpuA1Gs1yCVXbKmiSIepnkl5KS7zfL+mgVD5D0selzFG3lj4zR9LzKUPUzZLa3oCDglkmimpbRXewe4aoNcAJEfF14BXgqtJ7m0uZoy4uld8CLKVYzHVWg3PuxkHBLIcYw1bldBGPAdt3KXs0IgbT7jqKJdubSqs/HxgRj0dEAHcBZ7W7toOCWRYVmw57niGq5gLg4dL+TEnPSPpvSd9OZdMockLUOEOU2V5VvWnQaYYoACT9hGIp97tT0QBwVES8J2kO8ICk42meOaolBwWzXIa7fwlJS4DvA6emJgEpsezO9HqDpM3AMRQ1g3ITwxmizPaa2iIrmUYfGpG0ALgCODMidpTKD01p55F0NEWH4msRMQB8JGleGnU4D3iw3XVcUzDLZAwjC+3P1ThD1FXA/sCaNLK4Lo00nAJcJ2kQGAIujohaJ+UlFCMZX6Dogyj3QzTkoGCWS8agMJYMURFxH3Bfk/fWAyeM5dpdCQqvbJrKgr/6u26c2pJ9vj7U61voe3rlf3p9Cz3hmoJZJjmbD73koGCWi1deMrMRwV4ZktwbHBTMMnHzwczqOSiYWR0HBTOrGeO06HHNQcEsF48+mFkd1xTMrEwekjSzEe5TMLPdOCiYWR0HBTMr65fmg1deMrM6rimY5eKagpmNiGJIsspWRZMMUVMlrZH0avp5cCpXyv60KWWP+lbpM0vS8a+mRV/bclAwyyVjMhgaZ4i6ElgbEbOAtWkf4HRGM0AtpcgKhaSpFGs7ngScCFxTCyStOCiYZSDypo1rlCEKWAjcmV7fyWi2p4XAXVFYBxyUskPNB9ZExPaIeJ8i7VzbtHHuUzDLpXot4BBJ60v7yyNieYXPHZ6WbSciBiQdlsqnAW+VjqtlgmpW3pKDglkOY3uicY8yRDXQLBNURxmi3HwwyyVvn0IjW1OzoJY89t1UvgWYXjqulgmqWXlLDgpmmeQcfWhiFVAbQVjCaLanVcB5aRRiHvBBamY8Apwm6eDUwXhaKmvJzQezXLqfIeqnwEpJFwJvAuekw1cDZwCbgB3A+QARsV3S9cBT6bjrSpmjmnJQMMthz5sG9adrnCEK4NQGxwawrMl5VgArxnJtBwWzTPpl7oODglkuDgpmVuaagpnVc1Awsxov8W5mu3NQMLMy1xTMrJ6DgpnVcVAwsxHuaDSz3TgomFmZ08aZWR03H8xsVOZZkr3koGCWi4OCmdXUVnPuBw4KZrn0SVBou0Zjo0w1ZrY7RVTa2p5HOlbSs6XtQ0mXS7pW0tul8jNKn7kqZYh6WdL8PfkeVRZuvYMKCSTMJrWMaeMi4uWImB0Rs4E5FOsu3p/evqn2XkSsBpB0HLAIOJ7i/+qvJE3p9Ku0DQpNMtWY2a66s8T7qcDmiPhTi2MWAvdGxM6IeJ1iAdcTx3ylJNsS75KWSlovaf1ngztyndZswhhD2rhDav9X0ra0xWkXAfeU9i9NSWRXlPJCdpQJqplsQSEilkfE3IiYu9++B+Q6rdnEUb2msK32fyVtDVPGSfoccCbw76noFuBrwGxgALixdmiTu+mIRx/McujOhKjTgacjYitA7SeApNuAh9JuR5mgmnGGKLNc8vcpLKbUdKiljEvOBmojgquARZL2lzSTIiX9kx19ByrUFBplqomI2zu9oFk/yv3wkqQDgO8BF5WK/0nSbIrQ8kbtvYjYKGkl8AIwCCyLiKFOr902KLTIVGNmJRrOFxUiYgfwlV3Kzm1x/A3ADTmu7T4Fsxw8IcrMduX1FMysnmsKZlbmWZJmNiqACpOdJgIHBbNM3KdgZiO8yIqZ1Ytw88HM6rmmYGb1HBTMrMw1BTMbFUDGuQ+95KBglomHJM2snkcfzKzMfQpmNspTp82srHiisT+igtdoNMtluOJWgaQ3JD2fMkGtT2VTJa2R9Gr6eXAql6SbU4ao5yR9a0++hoOCWSa50saVfDdlgpqb9q8E1kbELGBt2odi1edZaVtKsRR8xxwUzHKIKJ5TqLJ1biFwZ3p9J3BWqfyuKKwDDtpl5ecxcVAwyyRzhqgAHpW0ofT+4RExAJB+HpbKs2aIckejWS7VmwbbSk2CZk6OiHckHQaskfRSi2OzZohyTcEsh4xZpwEi4p30812KjNMnAltrzYL08910uDNEmY1LtTUV2m1tSPqipC/XXgOnUWSDWgUsSYctAR5Mr1cB56VRiHnAB7VmRifcfDDLJd9jCocD90uC4v/ov0XEbyU9BayUdCHwJnBOOn41cAZFCvodwPl7cnEHBbNMcj28FBGvAd9oUP4ecGqD8gCWZbk4DgpmeQQw1B9PNDoomGUgxvxg0rjloGCWi4OCmdVxUDCzEUHlyU7jnYOCWSbuUzCzeg4KZjYiAob7o/3goGCWS3/EBAcFs1zcp2Bm9RwUzGyEM0S19tGOgW1rnv77P3Xj3F1yCLCt1zfR5ybi7/jPqx/qVPQtRcSh3Thvt0haX2ElHNsDk+J37KBgZiMCGOqP4QcHBbMsAqI/goKXYyss7/UNTAL9/zvOtxzbdEm/k/SipI2SLkvl10p6OyWIeVbSGaXPXJWSwbwsaf6efA3XFICI6P9/sD3W97/jvKMPg8CPIuLptFbjBklr0ns3RcQ/lw+WdBywCDge+Crwn5KOiYihTi7umoJZLplqChExEBFPp9cfAS/SOo/DQuDeiNgZEa9TrNV4Yqdfw0HBLJdMQaFM0gzgm8ATqejSlC9yRS2XJJmTwTgomOUQAUND1bZqGaKQ9CXgPuDyiPiQIkfk14DZwABwY+3QRnfU6Vdxn4JZLhkzREnajyIg3B0RvylOH1tL798GPJR2nQzGbFzKN/og4HbgxYj4eam8nDT2bIoEMVAkg1kkaX9JMymyTz/Z6ddwTcEsiz3OKF12MnAu8LykZ1PZj4HFkmYXF+MN4CKAiNgoaSXwAsXIxbJORx7AQcEsj4DI9PBSRPyBxv0Eq1t85gbghhzXd1Awy8WzJM2sjidEmdmI2pBkH3BQMMskvHCrmY3yIitmVubl2MxsN32ynoKDglkGAYRrCmY2Ivpn5SUHBbNMok+GJBV90mNq1kuSfkuxjH0V2yJiQTfvZ084KJhZHU+dNrM6DgpmVsdBwczqOCiYWR0HBTOr8/9uDb05Fa5ZDwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# confusion matrix\n", + "mat = confusion_matrix(np.argmax(y_test, axis=1), np.argmax(model.predict(X_test), axis=1))\n", + "\n", + "display(mat)\n", + "plt.matshow(mat);\n", + "plt.colorbar()\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# the dectinary\n", + "FI_dic1= {0:[],1:[],2:[],3:[],4:[],5:[],6:[],7:[],8:[],9:[]}\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 420 samples, validate on 3166 samples\n", + "Epoch 1/250\n", + "420/420 [==============================] - ETA: 1s - loss: 0.8020 - accuracy: 0.15 - 0s 475us/step - loss: 0.6159 - accuracy: 0.6571 - val_loss: 0.5505 - val_accuracy: 0.7454\n", + "Epoch 2/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4634 - accuracy: 0.81 - 0s 221us/step - loss: 0.5195 - accuracy: 0.7667 - val_loss: 0.5165 - val_accuracy: 0.7454\n", + "Epoch 3/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4718 - accuracy: 0.78 - 0s 197us/step - loss: 0.4819 - accuracy: 0.7667 - val_loss: 0.4850 - val_accuracy: 0.7454\n", + "Epoch 4/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4996 - accuracy: 0.78 - 0s 202us/step - loss: 0.4604 - accuracy: 0.7667 - val_loss: 0.4648 - val_accuracy: 0.7454\n", + "Epoch 5/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4819 - accuracy: 0.81 - 0s 216us/step - loss: 0.4435 - accuracy: 0.7690 - val_loss: 0.4511 - val_accuracy: 0.7527\n", + "Epoch 6/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.5019 - accuracy: 0.71 - 0s 180us/step - loss: 0.4364 - accuracy: 0.7667 - val_loss: 0.4513 - val_accuracy: 0.7489\n", + "Epoch 7/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4597 - accuracy: 0.71 - 0s 164us/step - loss: 0.4270 - accuracy: 0.7857 - val_loss: 0.4398 - val_accuracy: 0.7748\n", + "Epoch 8/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4168 - accuracy: 0.75 - 0s 178us/step - loss: 0.4183 - accuracy: 0.8048 - val_loss: 0.4361 - val_accuracy: 0.7761\n", + "Epoch 9/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3737 - accuracy: 0.75 - 0s 199us/step - loss: 0.4082 - accuracy: 0.8167 - val_loss: 0.4300 - val_accuracy: 0.7738\n", + "Epoch 10/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4122 - accuracy: 0.84 - 0s 180us/step - loss: 0.4041 - accuracy: 0.8024 - val_loss: 0.4348 - val_accuracy: 0.7798\n", + "Epoch 11/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.5593 - accuracy: 0.62 - 0s 185us/step - loss: 0.3933 - accuracy: 0.8167 - val_loss: 0.4230 - val_accuracy: 0.7817\n", + "Epoch 12/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3269 - accuracy: 0.84 - 0s 190us/step - loss: 0.3849 - accuracy: 0.8167 - val_loss: 0.4186 - val_accuracy: 0.7890\n", + "Epoch 13/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4152 - accuracy: 0.87 - 0s 197us/step - loss: 0.3739 - accuracy: 0.8190 - val_loss: 0.4109 - val_accuracy: 0.7922\n", + "Epoch 14/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3175 - accuracy: 0.87 - 0s 212us/step - loss: 0.3681 - accuracy: 0.8357 - val_loss: 0.4049 - val_accuracy: 0.7953\n", + "Epoch 15/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2642 - accuracy: 0.87 - 0s 254us/step - loss: 0.3643 - accuracy: 0.8190 - val_loss: 0.3993 - val_accuracy: 0.8061\n", + "Epoch 16/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3769 - accuracy: 0.84 - 0s 230us/step - loss: 0.3681 - accuracy: 0.8286 - val_loss: 0.3994 - val_accuracy: 0.8016\n", + "Epoch 17/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4769 - accuracy: 0.84 - 0s 214us/step - loss: 0.3530 - accuracy: 0.8381 - val_loss: 0.3918 - val_accuracy: 0.8092\n", + "Epoch 18/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3678 - accuracy: 0.87 - 0s 202us/step - loss: 0.3377 - accuracy: 0.8310 - val_loss: 0.3871 - val_accuracy: 0.8130\n", + "Epoch 19/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3190 - accuracy: 0.81 - 0s 230us/step - loss: 0.3469 - accuracy: 0.8310 - val_loss: 0.4820 - val_accuracy: 0.7738\n", + "Epoch 20/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3204 - accuracy: 0.84 - 0s 199us/step - loss: 0.3986 - accuracy: 0.8167 - val_loss: 0.4357 - val_accuracy: 0.7713\n", + "Epoch 21/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4518 - accuracy: 0.75 - 0s 211us/step - loss: 0.3504 - accuracy: 0.8286 - val_loss: 0.4117 - val_accuracy: 0.7979\n", + "Epoch 22/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4504 - accuracy: 0.81 - 0s 197us/step - loss: 0.3356 - accuracy: 0.8357 - val_loss: 0.3779 - val_accuracy: 0.8222\n", + "Epoch 23/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3286 - accuracy: 0.81 - 0s 181us/step - loss: 0.3188 - accuracy: 0.8357 - val_loss: 0.4099 - val_accuracy: 0.7982\n", + "Epoch 24/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.5356 - accuracy: 0.81 - 0s 178us/step - loss: 0.3277 - accuracy: 0.8310 - val_loss: 0.3724 - val_accuracy: 0.8247\n", + "Epoch 25/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2666 - accuracy: 0.87 - 0s 192us/step - loss: 0.3120 - accuracy: 0.8476 - val_loss: 0.3793 - val_accuracy: 0.8181\n", + "Epoch 26/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.1937 - accuracy: 0.90 - 0s 191us/step - loss: 0.3137 - accuracy: 0.8595 - val_loss: 0.3741 - val_accuracy: 0.8272\n", + "Epoch 27/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3485 - accuracy: 0.81 - 0s 228us/step - loss: 0.3016 - accuracy: 0.8571 - val_loss: 0.3725 - val_accuracy: 0.8241\n", + "Epoch 28/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.1122 - accuracy: 1.00 - 0s 261us/step - loss: 0.2970 - accuracy: 0.8548 - val_loss: 0.3731 - val_accuracy: 0.8222\n", + "Epoch 29/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2215 - accuracy: 0.90 - 0s 197us/step - loss: 0.3014 - accuracy: 0.8476 - val_loss: 0.3781 - val_accuracy: 0.8174\n", + "Epoch 30/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2079 - accuracy: 0.84 - 0s 204us/step - loss: 0.2942 - accuracy: 0.8548 - val_loss: 0.3813 - val_accuracy: 0.8181\n", + "Epoch 31/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3559 - accuracy: 0.75 - 0s 225us/step - loss: 0.2938 - accuracy: 0.8429 - val_loss: 0.3857 - val_accuracy: 0.8105\n", + "Epoch 32/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4537 - accuracy: 0.78 - 0s 239us/step - loss: 0.3120 - accuracy: 0.8500 - val_loss: 0.3785 - val_accuracy: 0.8184\n", + "Epoch 33/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.1957 - accuracy: 0.90 - 0s 229us/step - loss: 0.3323 - accuracy: 0.8500 - val_loss: 0.4259 - val_accuracy: 0.7997\n", + "Epoch 34/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2729 - accuracy: 0.87 - 0s 212us/step - loss: 0.2830 - accuracy: 0.8738 - val_loss: 0.3710 - val_accuracy: 0.8282\n", + "Epoch 35/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.4034 - accuracy: 0.78 - 0s 245us/step - loss: 0.2836 - accuracy: 0.8643 - val_loss: 0.3741 - val_accuracy: 0.8272\n", + "Epoch 36/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2914 - accuracy: 0.87 - 0s 219us/step - loss: 0.2785 - accuracy: 0.8667 - val_loss: 0.3696 - val_accuracy: 0.8256\n", + "Epoch 37/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2081 - accuracy: 0.90 - 0s 252us/step - loss: 0.2737 - accuracy: 0.8690 - val_loss: 0.3721 - val_accuracy: 0.8244\n", + "Epoch 38/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2847 - accuracy: 0.90 - 0s 185us/step - loss: 0.2932 - accuracy: 0.8762 - val_loss: 0.3984 - val_accuracy: 0.8149\n", + "Epoch 39/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2068 - accuracy: 0.96 - 0s 275us/step - loss: 0.2746 - accuracy: 0.8786 - val_loss: 0.3747 - val_accuracy: 0.8260\n", + "Epoch 40/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2561 - accuracy: 0.87 - 0s 235us/step - loss: 0.2991 - accuracy: 0.8524 - val_loss: 0.3819 - val_accuracy: 0.8253\n", + "Epoch 41/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2118 - accuracy: 0.90 - 0s 188us/step - loss: 0.2920 - accuracy: 0.8429 - val_loss: 0.3901 - val_accuracy: 0.8143\n", + "Epoch 42/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.1978 - accuracy: 0.90 - 0s 230us/step - loss: 0.2979 - accuracy: 0.8667 - val_loss: 0.4308 - val_accuracy: 0.8080\n", + "Epoch 43/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.3652 - accuracy: 0.75 - 0s 233us/step - loss: 0.2874 - accuracy: 0.8571 - val_loss: 0.3759 - val_accuracy: 0.8266\n", + "Epoch 44/250\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "420/420 [==============================] - ETA: 0s - loss: 0.2414 - accuracy: 0.90 - 0s 211us/step - loss: 0.2671 - accuracy: 0.8643 - val_loss: 0.3861 - val_accuracy: 0.8152\n", + "Epoch 45/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.1696 - accuracy: 0.96 - 0s 211us/step - loss: 0.2889 - accuracy: 0.8738 - val_loss: 0.3877 - val_accuracy: 0.8234\n", + "Epoch 46/250\n", + "420/420 [==============================] - ETA: 0s - loss: 0.2738 - accuracy: 0.84 - 0s 220us/step - loss: 0.2635 - accuracy: 0.8738 - val_loss: 0.3914 - val_accuracy: 0.8222\n", + "Epoch 47/250\n", + " 32/420 [=>............................] - ETA: 0s - loss: 0.2366 - accuracy: 0.9062" + ] + } + ], + "source": [ + "# select aa random peer to be the scanner peer\n", + "peers_selected=random.sample(range(n_peers), 1)\n", + "scaner = peers_selected[0]\n", + "\n", + "# Percentage and number of peers participating at each global training epoch\n", + "percentage_participants = 1.0\n", + "n_participants = int(n_peers * percentage_participants)\n", + "\n", + "# Number of global training epochs\n", + "n_rounds = 10\n", + "start_attack_round = 4\n", + "end_attack_round = 7\n", + "# Number of local training epochs per global training epoch\n", + "n_local_rounds = 5\n", + "\n", + "# Local batch size\n", + "local_batch_size = 32\n", + "\n", + "# Local learning rate\n", + "local_lr = 0.001\n", + "\n", + "# Global learning rate or 'gain'\n", + "model_substitution_rate = 1.0\n", + "\n", + "# Attack detection / prevention mechanism = {None, 'distance', 'median', 'accuracy', 'krum'}\n", + "discard_outliers = None\n", + "\n", + "# Used in 'dist' attack detection, defines how far the outliers are (1.5 is a typical value)\n", + "tau = 1.5\n", + "\n", + "# Used in 'accuracy' attack detection, defines the error margin for the accuracy improvement\n", + "sensitivity = 0.05\n", + "\n", + "# Used in 'krum' attack detection, defines how many byzantine attackers we want to defend against\n", + "tolerance=4\n", + "\n", + "# Prevent suspicious peers from participating again, only valid for 'dist' and 'accuracy'\n", + "ban_malicious = False\n", + "\n", + "# Clear nans and infinites in model updates\n", + "clear_nans = True\n", + "\n", + "number_for_threshold1 = numpy.empty(20, dtype=float)\n", + "number_for_threshold2 = numpy.empty(20, dtype=float)\n", + "for r in range(len(number_for_threshold1)):\n", + " number_for_threshold1[r] = 0\n", + " number_for_threshold2[r] = 0\n", + "\n", + "########################\n", + "# ATTACK CONFIGURATION #\n", + "########################\n", + "\n", + "# Percentage of malicious peers\n", + "r_malicious_peers = 0.0\n", + "\n", + "# Number of malicious peers (absolute or relative to total number of peers)\n", + "n_malicious_peers = int(n_peers * r_malicious_peers)\n", + "#n_malicious_peers = 1\n", + "\n", + "# Malicious peers\n", + "malicious_peer = range(n_malicious_peers)\n", + "\n", + "# Target for coalitions\n", + "common_attack_target = [4,7]\n", + "\n", + "# Target class of the attack, per each malicious peer\n", + "malicious_targets = dict([(p, t) for p,t in zip(malicious_peer, [common_attack_target]*n_malicious_peers)])\n", + "\n", + "# Boosting parameter per each malicious peer\n", + "common_malicious_boost = 12\n", + "malicious_boost = dict([(p, b) for p,b in zip(malicious_peer, [common_malicious_boost]*n_malicious_peers)])\n", + "\n", + "###########\n", + "# METRICS #\n", + "###########\n", + "metrics = {'accuracy': [],\n", + " 'atk_effectivity': [],\n", + " 'update_distances': [],\n", + " 'outliers_detected': [],\n", + "\n", + " 'acc_no_target': []}\n", + "\n", + "####################################\n", + "# MODEL AND NETWORK INITIALIZATION #\n", + "####################################\n", + "inputs = X_train[0*ss:0*ss+ss]\n", + "outputs = y_train[0*ss:0*ss+ss]\n", + "global_model = build_model(inputs,outputs)\n", + "n_layers = len(trainable_layers(global_model))\n", + "\n", + "print('Initializing network.')\n", + "sleep(1)\n", + "network = []\n", + "for i in tqdm(range(n_peers)):\n", + " ss = int(len(X_train)/n_peers)\n", + " inputs = X_train[i*ss:i*ss+ss]\n", + " outputs = y_train[i*ss:i*ss+ss]\n", + "# network.append(build_model(inputs, outputs))\n", + " network.append(global_model)\n", + "\n", + "\n", + "banned_peers = set()\n", + "\n", + "##################\n", + "# BEGIN TRAINING #\n", + "##################\n", + "for t in range(n_rounds):\n", + " print(f'Round {t+1}.')\n", + " sleep(1)\n", + "\n", + " ## SERVER SIDE #################################################################\n", + " # Fetch global model parameters\n", + " global_weights, global_biases = get_parameters(global_model)\n", + "\n", + " if clear_nans:\n", + " global_weights, global_biases = nans_to_zero(global_weights, global_biases)\n", + "\n", + " # Initialize peer update lists\n", + " network_weight_updates = []\n", + " network_bias_updates = []\n", + "\n", + " # Selection of participant peers in this global training epoch\n", + " if ban_malicious:\n", + " good_peers = list([p for i,p in enumerate(network) if i not in banned_peers])\n", + " n_participants = n_participants if n_participants <= len(good_peers) else int(len(good_peers) * percentage_participants)\n", + " participants = random.sample(list(enumerate(good_peers)), n_participants)\n", + " else:\n", + " participants = random.sample(list(enumerate(network)),n_participants)\n", + " ################################################################################\n", + "\n", + "\n", + " ## CLIENT SIDE #################################################################\n", + " for i, local_model in tqdm(participants):\n", + "\n", + " # Update local model with global parameters \n", + " set_parameters(local_model, global_weights, global_biases)\n", + "\n", + " # Initialization of user data\n", + " ss = int(len(X_train)/n_peers)\n", + " inputs = X_train[i*ss:i*ss+ss]\n", + " outputs = y_train[i*ss:i*ss+ss]\n", + "\n", + "# the scanner peer side\n", + " if(i == scaner):\n", + " X_train_local, X_test_local, y_train_local, y_test_local = train_test_split(inputs,outputs, test_size=0.7, random_state=rs)\n", + " inputs = X_train_local\n", + " outputs = y_train_local\n", + " if(t == 0):\n", + " forest = build_forest(X_train_local,y_train_local)\n", + " forest_predictions = forest.predict(X_test_local)\n", + " acc_forest = np.mean([t==p for t,p in zip(y_test_local, forest_predictions)])\n", + " FL_predict1 = global_model.predict(X_test_local)\n", + " imp = scan_wrong(forest_predictions, FL_predict1, forest , y_test_local, X_test_local)\n", + " FI_dic1[t] = imp\n", + "\n", + "\n", + " # Benign peer\n", + " # Train local model \n", + " local_weight_updates, local_bias_updates = get_updates(local_model, \n", + " inputs, outputs, \n", + " local_batch_size, n_local_rounds)\n", + " if clear_nans:\n", + " local_weight_updates, local_bias_updates = nans_to_zero(local_weight_updates, local_bias_updates)\n", + " network_weight_updates.append(local_weight_updates)\n", + " network_bias_updates.append(local_bias_updates)\n", + "\n", + " ## END OF CLIENT SIDE ##########################################################\n", + "\n", + " ######################################\n", + " # SERVER SIDE AGGREGATION MECHANISMS #\n", + " ######################################\n", + "\n", + "\n", + " # Aggregate client updates\n", + " aggregated_weights, aggregated_biases = aggregate(n_layers, \n", + " n_participants, \n", + " np.mean, \n", + " network_weight_updates, \n", + " network_bias_updates)\n", + "\n", + " if clear_nans:\n", + " aggregated_weights, aggregated_biases = nans_to_zero(aggregated_weights, aggregated_biases)\n", + "\n", + " # Apply updates to global model\n", + " apply_updates(global_model, model_substitution_rate, aggregated_weights, aggregated_biases)\n", + "\n", + " # Proceed as in first case\n", + " aggregated_weights, aggregated_biases = aggregate(n_layers, \n", + " n_participants, \n", + " np.mean, \n", + " network_weight_updates, \n", + " network_bias_updates)\n", + " if clear_nans:\n", + " aggregated_weights, aggregated_biases = nans_to_zero(aggregated_weights, aggregated_biases)\n", + "\n", + " apply_updates(global_model, model_substitution_rate, aggregated_weights, aggregated_biases)\n", + "\n", + " ###################\n", + " # COMPUTE METRICS #\n", + " ###################\n", + "\n", + " # Global model accuracy\n", + " score = global_model.evaluate(X_test, y_test, verbose=0)\n", + " print(f'Global model loss: {score[0]}; global model accuracy: {score[1]}')\n", + " metrics['accuracy'].append(score[1])\n", + "\n", + "\n", + " # Accuracy without the target\n", + " score = global_model.evaluate(X_test, y_test, verbose=0)\n", + " metrics['acc_no_target'].append(score[1])\n", + "\n", + "\n", + " # Distance of individual updates to the final aggregation\n", + " metrics['update_distances'].append([dist_weights(aggregated_weights, w_i) for w_i in network_weight_updates])\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# sort the feature according to the last epoch and print it with importances\n", + "\n", + "sort_index = np.argsort(FI_dic1[9])\n", + "for x in sort_index:\n", + " print(names[x], ', ', FI_dic1[9][x])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}