import numpy as np import tensorflow as tf class SaliencyMap(): def __init__(self, model): """Constructs a Vanilla Gradient Map by computing dy/dx. Args: model: The TensorFlow model used to evaluate Gradient Map. model takes image as input and outputs probabilities vector. """ self.model = model def get_top_predicted_idx(self, image): """Outputs top predicted class for the input image. Args: img_processed: numpy image array in NHWC format, pre-processed according to the defined model standard. Returns: Index of the top predicted class for the input image. """ preds = self.model.predict(image) # top_pred_idx = tf.argmax(preds[0]) top_pred_idx = 1 return top_pred_idx def get_gradients(self, image): """Computes the gradients of outputs w.r.t input image. Args: image: numpy image array in NHWC format, pre-processed according to the defined model standard. Returns: Gradients of the predictions w.r.t image (same shape as input image) """ image = tf.convert_to_tensor(image) top_pred_idx = self.get_top_predicted_idx(image) with tf.GradientTape() as tape: tape.watch(image) preds = self.model(image) print("get_gradients, size of preds",np.shape(preds)) top_class = preds[:] print("get_gradients, size of top_class",np.shape(top_class)) grads = tape.gradient(top_class, image) return grads def norm_grad(self, grad_x): """Normalizes gradient to the range between 0 and 1 (for visualization purposes). Args: grad_x: numpy gradients array. Returns: Gradients of the predictions w.r.t image (same shape as input image) """ abs_grads = np.abs(grad_x) grad_max_ = np.max(abs_grads, axis=3)[0] arr_min, arr_max = np.min(grad_max_), np.max(grad_max_) normalized_grad = (grad_max_ - arr_min) / (arr_max - arr_min + 1e-18) normalized_grad = normalized_grad.reshape(1,grad_x.shape[1],grad_x.shape[2],1) return normalized_grad def get_mask(self, image, tensor_format=False): """Returns a saliency mask specific to each method. Args: image: input image in NHWC format, not batched. """ raise NotImplementedError('A derived class should implement get_mask()')