fast-mri/src/sfransen/Saliency/base.py

82 lines
2.5 KiB
Python
Executable File

import numpy as np
import tensorflow as tf
class SaliencyMap:
    """Base class for gradient-based saliency maps of a TensorFlow model.

    Subclasses implement :meth:`get_mask`. This base class provides shared
    helpers for computing input gradients and normalizing them for display.
    """

    def __init__(self, model):
        """Constructs a Vanilla Gradient Map by computing dy/dx.

        Args:
            model: The TensorFlow model used to evaluate the gradient map.
                Per the original author, it takes an image as input and
                outputs a probabilities vector. It is invoked as
                ``model(image)`` inside a ``tf.GradientTape``.
        """
        self.model = model

    def get_top_predicted_idx(self, image):
        """Returns the class index used as the saliency target.

        The index is currently fixed to ``1``. An argmax over the model's
        predictions was used here previously but has been disabled. Because
        the result does not depend on the prediction, no forward pass is run
        (the previous version ran ``model.predict`` and discarded the result).

        Args:
            image: numpy image array in NHWC format, pre-processed according
                to the defined model standard. Currently unused; kept for
                interface compatibility.

        Returns:
            The target class index (always ``1``).
        """
        # NOTE(review): to follow the model's prediction again, restore
        # something like ``tf.argmax(self.model.predict(image)[0])``.
        return 1

    def get_gradients(self, image):
        """Computes the gradients of the model outputs w.r.t. the input image.

        Args:
            image: numpy image array (or tensor) in NHWC format, pre-processed
                according to the defined model standard.

        Returns:
            Gradients of the predictions w.r.t. ``image`` (same shape as the
            input image).
        """
        image = tf.convert_to_tensor(image)
        with tf.GradientTape() as tape:
            # A converted constant is not a trainable Variable, so the tape
            # must be told explicitly to record operations on it.
            tape.watch(image)
            preds = self.model(image)
        # NOTE(review): the whole prediction tensor is the target (no class is
        # selected), so TensorFlow differentiates the sum of all outputs.
        # get_top_predicted_idx() is not applied here; confirm this is intended.
        return tape.gradient(preds, image)

    def norm_grad(self, grad_x):
        """Collapses gradients to a single-channel map scaled to [0, 1].

        Takes the absolute gradient and reduces over axis 3 (channels in NHWC)
        with a max. It keeps only the first batch element and min-max scales
        the result (for visualization purposes).

        Args:
            grad_x: gradients array (numpy array or tensor) in NHWC format.

        Returns:
            numpy array of shape ``(1, H, W, 1)``. A constant map yields zeros
            rather than NaN thanks to the tiny epsilon in the denominator.
        """
        saliency = np.max(np.abs(grad_x), axis=3)[0]  # (H, W); first sample only
        lo, hi = np.min(saliency), np.max(saliency)
        normalized = (saliency - lo) / (hi - lo + 1e-18)
        return normalized.reshape(1, grad_x.shape[1], grad_x.shape[2], 1)

    def get_mask(self, image, tensor_format=False):
        """Returns a saliency mask specific to each method.

        Must be overridden by subclasses.

        Args:
            image: input image. The original notes describe it as "NHWC
                format, not batched"; confirm the expected layout per subclass.
            tensor_format: interpretation is left to subclasses (unused here).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('A derived class should implement get_mask()')