82 lines
2.5 KiB
Python
Executable File
82 lines
2.5 KiB
Python
Executable File
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
class SaliencyMap:
    """Base class for gradient-based saliency maps.

    Subclasses implement ``get_mask``; this base class provides shared helpers
    to pick the class to explain, compute input gradients, and normalize them
    for display.
    """

    def __init__(self, model):
        """Constructs a Vanilla Gradient Map by computing dy/dx.

        Args:
            model: The TensorFlow model used to evaluate Gradient Map.
                model takes image as input and outputs probabilities vector.
                Both ``model.predict(x)`` and ``model(x)`` are called by the
                helpers below, so the model must support both.
        """
        self.model = model

    def get_top_predicted_idx(self, image):
        """Outputs top predicted class for the input image.

        Args:
            image: numpy image array in NHWC format, pre-processed according
                to the defined model standard.

        Returns:
            Index (python ``int``) of the highest-scoring output for the
            first image in the batch.
        """
        preds = self.model.predict(image)
        # Only the first sample is inspected; callers are expected to pass a
        # batch containing a single image.
        return int(np.argmax(preds[0]))

    def get_gradients(self, image, class_idx=None):
        """Computes the gradient of one class score w.r.t. the input image.

        Args:
            image: numpy image array in NHWC format, pre-processed according
                to the defined model standard.
            class_idx: Optional index of the output unit to explain. When
                ``None`` (the default), the top predicted class returned by
                ``get_top_predicted_idx`` is used.

        Returns:
            Gradients of the selected class score w.r.t. image (same shape as
            the input image).
        """
        image = tf.convert_to_tensor(image)
        # GradientTape does not propagate gradients into integer tensors
        # (tape.gradient would return None), so promote e.g. uint8 images.
        if not image.dtype.is_floating:
            image = tf.cast(image, tf.float32)

        if class_idx is None:
            class_idx = self.get_top_predicted_idx(image)

        with tf.GradientTape() as tape:
            tape.watch(image)
            preds = self.model(image)
            # Differentiate a single class score. Differentiating the whole
            # prediction tensor implicitly sums all outputs; for a probability
            # vector that sums to 1 that yields an all-zero gradient.
            class_score = preds[:, class_idx]

        return tape.gradient(class_score, image)

    def norm_grad(self, grad_x):
        """Normalizes gradient to the range between 0 and 1
        (for visualization purposes).

        The absolute gradient of the first image in the batch is reduced with
        a max over axis 3 (the channel axis in NHWC), then min-max scaled.

        Args:
            grad_x: gradients array (numpy array or tensor) in NHWC format.

        Returns:
            numpy array of shape (1, H, W, 1) with values in [0, 1].
        """
        abs_grads = np.abs(np.asarray(grad_x))
        grad_max_ = np.max(abs_grads, axis=3)[0]
        arr_min, arr_max = np.min(grad_max_), np.max(grad_max_)
        # The tiny epsilon avoids division by zero when the map is constant.
        normalized_grad = (grad_max_ - arr_min) / (arr_max - arr_min + 1e-18)
        return normalized_grad.reshape(1, abs_grads.shape[1], abs_grads.shape[2], 1)

    def get_mask(self, image, tensor_format=False):
        """Returns a saliency mask specific to each method.

        Args:
            image: input image in NHWC format, not batched.
                NOTE(review): NHWC implies a batch axis; confirm the rank
                subclasses actually expect.
            tensor_format: flag forwarded to subclass implementations; its
                meaning is defined by the subclass.

        Raises:
            NotImplementedError: always; derived classes must override this.
        """
        raise NotImplementedError('A derived class should implement get_mask()')