# fast-mri/src/sfransen/Saliency/heatmap.py
#
# 86 lines
# 3.3 KiB
# Python
# Executable File

import os
import numpy as np
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import scipy.ndimage as ndimage
class HeatMap:
    """Overlay a smoothed heat map on top of an image.

    The heat map is scaled by 255, resized to the image's size and blurred
    with a Gaussian filter once, at construction time. ``plot`` and ``save``
    then render the image in grayscale with the heat map drawn
    semi-transparently on top of it.
    """

    def __init__(self, image, heat_map, gaussian_std=10):
        """Store the image and prepare the resized, smoothed heat map.

        Args:
            image: Either a numpy array (``shape[0]`` is taken as the height,
                ``shape[1]`` as the width) or anything ``PIL.Image.open``
                accepts, such as a file path.
            heat_map: 2-D array-like of heat values. Presumably normalised to
                [0, 1] (it is multiplied by 255 here and divided by 255 again
                when drawn) -- TODO confirm against callers.
            gaussian_std: Standard deviation, in pixels of the resized map,
                of the Gaussian blur applied along both axes.

        Attributes set:
            image: The numpy array as given, or the opened PIL image.
            heat_map: float32 array of shape ``(height, width)`` holding the
                resized, blurred heat map scaled by 255.
        """
        if isinstance(image, np.ndarray):
            height, width = image.shape[0], image.shape[1]
        else:
            image = Image.open(image)
            # Image.open is lazy and keeps the file open until the pixel data
            # is read; loading now lets Pillow release a handle it opened
            # itself (for single-frame images).
            image.load()
            width, height = image.size
        self.image = image

        # Route through a PIL float ("F") image to get an easy resize to the
        # target (width, height), then smooth the upsampled map.
        heat_image = Image.fromarray(self._scaled_heat_map(heat_map))
        heat_image = heat_image.resize((width, height))
        self.heat_map = ndimage.gaussian_filter(
            np.asarray(heat_image),
            sigma=(gaussian_std, gaussian_std),
            order=0,
        )

    @staticmethod
    def _scaled_heat_map(heat_map):
        """Return ``heat_map * 255`` as a float32 numpy array.

        Casting before multiplying avoids integer wrap-around (e.g. a uint8
        value of 2 would otherwise become 254), and float32 is exactly the
        pixel type of PIL's ``"F"`` mode used for resizing.
        """
        return np.asarray(heat_map, dtype=np.float32) * 255

    def _draw(self, transparency, color_map, show_axis, show_original,
              show_colorbar, width_pad, vmin=None, vmax=None):
        """Draw the overlay (and optionally the bare image) on the current figure.

        Draws into whatever pyplot figure is current; no new figure is
        created and the figure is not closed afterwards.
        """
        if show_original:
            # Bare image on the left, overlay in the second subplot.
            plt.subplot(1, 2, 1)
            if not show_axis:
                plt.axis('off')
            plt.imshow(self.image, cmap='gray')
            n_cols, index = 2, 2
        else:
            n_cols, index = 1, 1

        plt.subplot(1, n_cols, index)
        if not show_axis:
            plt.axis('off')
        plt.imshow(self.image, cmap='gray')
        # Undo the *255 applied in __init__ so the colours reflect the
        # heat map's own value range.
        plt.imshow(self.heat_map / 255, alpha=transparency, cmap=color_map,
                   vmin=vmin, vmax=vmax)
        if show_colorbar:
            plt.colorbar()
        plt.tight_layout(w_pad=width_pad)

    def plot(self, transparency=0.7, color_map='bwr',
             show_axis=False, show_original=False, show_colorbar=False,
             width_pad=0, vmin=None, vmax=None):
        """Draw the overlay into the current pyplot figure and show it.

        Args:
            transparency: Alpha of the heat-map layer.
            color_map: Matplotlib colormap name for the heat-map layer.
            show_axis: Keep the axes visible when True.
            show_original: Also draw the bare image in a left-hand subplot.
            show_colorbar: Add a colorbar for the heat-map layer.
            width_pad: ``w_pad`` forwarded to ``plt.tight_layout``.
            vmin, vmax: Optional colour limits for the heat-map layer; None
                lets matplotlib use the data range.
        """
        self._draw(transparency, color_map, show_axis, show_original,
                   show_colorbar, width_pad, vmin, vmax)
        plt.show()

    def save(self, filename, format='png', save_path=None,
             transparency=0.7, color_map='bwr', width_pad=-10,
             show_axis=False, show_original=False, show_colorbar=False,
             vmin=None, vmax=None, **kwargs):
        """Draw the overlay and write it to ``save_path/filename.format``.

        Args:
            filename: Output file name without extension.
            format: File extension and format passed to ``plt.savefig``.
            save_path: Target directory; defaults to the current working
                directory at the time of the call.
            transparency, color_map, width_pad, show_axis, show_original,
            show_colorbar, vmin, vmax: Same meaning as in ``plot``.
            **kwargs: Extra keyword arguments forwarded to ``plt.savefig``.
        """
        # Resolve at call time: putting os.getcwd() in the signature would
        # freeze whatever directory was current when the module was imported.
        if save_path is None:
            save_path = os.getcwd()
        self._draw(transparency, color_map, show_axis, show_original,
                   show_colorbar, width_pad, vmin, vmax)
        plt.savefig(os.path.join(save_path, filename + '.' + format),
                    format=format,
                    bbox_inches='tight',
                    pad_inches=0, **kwargs)
        print('{}.{} has been successfully saved to {}'.format(filename, format, save_path))