fast-mri/src/sfransen/Saliency/integrated_gradients.py

64 lines
2.4 KiB
Python
Executable File

import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import densenet
from base import SaliencyMap
class IntegratedGradients(SaliencyMap):
    """Saliency map computed with the Integrated Gradients method.

    Relies on ``get_top_predicted_idx`` and ``get_gradients``, which are
    expected to be provided by ``SaliencyMap`` (not visible in this file).
    """

    def get_mask(self, image, baseline=None, num_steps=2):
        """Compute Integrated Gradients of ``image`` relative to ``baseline``.

        The input is linearly interpolated from ``baseline`` to ``image`` in
        ``num_steps`` intervals, the gradient is evaluated at every
        interpolation point, and the path integral is approximated with the
        trapezoidal rule.

        Args:
            image (ndarray): Original image.
                NOTE(review): the interpolated images are stacked with
                ``np.vstack`` and then iterated along the first axis, so this
                presumably expects a leading batch axis of size 1 -- confirm
                against callers.
            baseline (ndarray, optional): Starting point of the interpolation.
                When omitted, an all-zero array with the shape of ``image``
                is used.
            num_steps (int): Number of interpolation intervals between the
                baseline and the input; ``num_steps + 1`` points are
                evaluated. Larger values reduce the integral approximation
                error. Defaults to 2.

        Returns:
            ``(image - baseline)`` multiplied element-wise by the
            trapezoidal-rule average of the gradients.

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        # Fail early with a clear message: num_steps == 0 would otherwise
        # divide by zero below, and negative values would produce an empty
        # interpolation list.
        if num_steps < 1:
            raise ValueError(
                "num_steps must be >= 1, got {}".format(num_steps)
            )

        # 1. If no baseline is provided, start from an all-zero image with
        #    the same shape as the input.
        if baseline is None:
            baseline = np.zeros(image.shape, dtype=np.float32)
        else:
            baseline = np.asarray(baseline, dtype=np.float32)
        print(">>>> step ONE completed")

        img_input = image
        # NOTE(review): the result is not used by this method. The call is
        # kept in case the base-class implementation has side effects --
        # confirm and remove if it does not.
        self.get_top_predicted_idx(image)

        # 2. Interpolate linearly between the baseline and the input.
        delta = img_input - baseline
        interpolated_image = [
            baseline + (step / num_steps) * delta
            for step in range(num_steps + 1)
        ]
        interpolated_image = np.vstack(interpolated_image).astype(np.float32)
        print(">>>> step TWO completed")

        # 3. Evaluate the gradient at every interpolation point.
        grads = []
        for i, img in enumerate(interpolated_image):
            print("number of image:", i)
            print("size of image:", np.shape(img))
            img = tf.expand_dims(img, axis=0)  # restore the batch axis
            grad = self.get_gradients(img)
            print("size of grad is:", np.shape(grad))
            # presumably get_gradients returns a batch; keep its only entry
            grads.append(grad[0])
        grads = tf.convert_to_tensor(grads, dtype=tf.float32)
        print(">>>> step THREE completed")

        # 4. Approximate the integral using the trapezoidal rule: average
        #    each pair of neighbouring gradients, then average over steps.
        grads = (grads[:-1] + grads[1:]) / 2.0
        avg_grads = tf.reduce_mean(grads, axis=0)
        print(">>>> step FOUR completed")

        # 5. Scale the averaged gradients by the input/baseline difference.
        integrated_grads = (img_input - baseline) * avg_grads
        print(">>>> step FIVE completed")
        return integrated_grads