114 lines
4.9 KiB
Python
Executable File
114 lines
4.9 KiB
Python
Executable File
from tensorflow.keras import backend as K
|
|
from tensorflow.keras import Input, Model
|
|
from tensorflow.keras.layers import Conv3D, Activation, Dense, concatenate, add
|
|
from tensorflow.keras.layers import Conv3DTranspose, LeakyReLU
|
|
from tensorflow.keras import regularizers
|
|
from tensorflow.keras.layers import GlobalAveragePooling3D, Reshape, Dense, multiply, Permute
|
|
|
|
def squeeze_excite_block(input, ratio=8):
    """Apply a channel-wise squeeze-and-excitation gate to a 5-D tensor.

    The input is globally average-pooled over its three spatial axes, passed
    through a two-layer bottleneck (ReLU, then sigmoid), and the resulting
    per-channel weights in [0, 1] are multiplied back onto the input.

    Args:
        input: Keras tensor produced by a 3-D conv stack. The channel axis is
            taken from ``K.image_data_format()``.
        ratio: reduction factor of the bottleneck. The hidden Dense layer has
            ``max(1, channels // ratio)`` units.

    Returns:
        A Keras tensor with the same shape as ``input``, rescaled per channel.

    Raises:
        ValueError: if the channel dimension of ``input`` is not statically
            known, or if ``ratio`` is smaller than 1.

    References:
        - [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
    """
    init = input
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    filters = init.shape[channel_axis]
    if filters is None:
        raise ValueError(
            "squeeze_excite_block needs a statically known channel dimension")
    if ratio < 1:
        raise ValueError(f"ratio must be >= 1, got {ratio!r}")

    # Squeeze: one scalar per channel, reshaped to a (1, 1, 1, C) map so the
    # Dense layers act on the channel axis and the gate broadcasts back over
    # the spatial axes in the multiply below.
    se_shape = (1, 1, 1, filters)
    se = GlobalAveragePooling3D()(init)
    se = Reshape(se_shape)(se)

    # Excite: bottleneck MLP. Clamp the hidden width to at least one unit so a
    # tensor with fewer channels than `ratio` does not yield a zero-unit Dense.
    se = Dense(max(1, filters // ratio), activation='relu',
               kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid',
               kernel_initializer='he_normal', use_bias=False)(se)

    if K.image_data_format() == 'channels_first':
        # The gate was built channels-last; move channels to axis 1 so it
        # lines up with the input layout.
        se = Permute((4, 1, 2, 3))(se)

    return multiply([init, se])
|
|
|
|
def build_dual_attention_unet(
    input_shape,
    l2_regularization=0.0001,
):
    """Build a 3-D encoder/decoder segmentation network with SE blocks.

    The encoder has five conv blocks. The last four downsample with strided
    convolutions: height and width at every stage, depth only at the last two.
    Each decoder stage concatenates the matching encoder feature map with
    transpose-convolution upsamplings of every deeper level. A residual
    squeeze-excite conv block follows each concatenation. A 1x1x1 convolution
    and a sigmoid produce a single-channel output map.

    Shape comments below use H, W, D for the three spatial input sizes, in the
    order they appear in ``input_shape``. They assume sizes divide evenly.

    Args:
        input_shape: shape of one sample without the batch axis, passed to
            ``Input``. This is (H, W, D, C) for channels_last or (C, H, W, D)
            for channels_first, per ``K.image_data_format()``.
        l2_regularization: L2 factor for the kernels of every conv-block
            convolution. It is not applied to the transpose convolutions or
            the final 1x1x1 output convolution.

    Returns:
        An uncompiled ``Model`` mapping the input to a one-channel sigmoid
        volume.

    Raises:
        ValueError: if a statically known spatial size cannot line up at the
            skip connections. ceil(H/2) and ceil(W/2) must be multiples of 8,
            and D must be a multiple of 4.
    """

    def check_spatial_dims():
        """Fail early, with a readable message, on unusable spatial sizes."""
        # With padding='same', a stride-2 Conv3D maps n -> ceil(n/2) and a
        # stride-2 Conv3DTranspose maps n -> 2n. Tracing the strides used
        # below, every concatenate lines up exactly when ceil(H/2) and
        # ceil(W/2) are multiples of 8 and D is a multiple of 4. Any other
        # shape would otherwise fail later at the first mismatching
        # concatenate with a less specific error.
        try:
            dims = tuple(input_shape)
        except TypeError:
            return  # leave malformed shapes for Keras to report
        if len(dims) != 4:
            return  # leave rank errors for Keras to report
        if K.image_data_format() == 'channels_first':
            h, w, d = dims[1:]
        else:
            h, w, d = dims[:3]
        problems = []
        for name, size in (('height', h), ('width', w)):
            if size is not None and ((size + 1) // 2) % 8:
                problems.append(
                    f'{name}={size} (ceil({name}/2) must be a multiple of 8)')
        if d is not None and d % 4:
            problems.append(f'depth={d} (must be a multiple of 4)')
        if problems:
            raise ValueError(
                f'input_shape {dims!r} is incompatible with the skip '
                'connections: ' + ', '.join(problems))

    def conv_layer(x, kernel_size, out_filters, strides=(1, 1, 1)):
        """Bias-free, L2-regularized, 'same'-padded Conv3D (no activation)."""
        return Conv3D(
            out_filters,
            kernel_size,
            strides=strides,
            padding='same',
            kernel_regularizer=regularizers.l2(l2_regularization),
            kernel_initializer='he_normal',
            use_bias=False,
        )(x)

    def conv_block(block_input, out_filters, strides=(1, 1, 1),
                   with_residual=False, with_se=False, activation='relu'):
        """Two 3x3x3 convolutions with optional SE gating and residual path.

        The first convolution applies ``strides`` and is always followed by
        ReLU. ``activation`` only selects the output nonlinearity: 'leaky'
        gives LeakyReLU(0.1), and any other value gives ReLU.
        """
        # First convolution; downsamples when strides > 1.
        x = conv_layer(block_input, (3, 3, 3), out_filters, strides)
        x = Activation('relu')(x)

        # Second, unstrided convolution.
        x = conv_layer(x, (3, 3, 3), out_filters)

        # Squeeze-excite: add the channel-gated features back onto x,
        # i.e. x + x * gate.
        if with_se:
            se = squeeze_excite_block(x)
            x = add([x, se])

        # Residual path: 1x1x1 projection with the same strides, so its
        # spatial size and channel count match x.
        if with_residual:
            residual = conv_layer(block_input, (1, 1, 1), out_filters, strides)
            x = add([x, residual])

        if activation == 'leaky':
            return LeakyReLU(alpha=.1)(x)
        return Activation('relu')(x)

    check_spatial_dims()

    inputs = Input(input_shape)

    # Encoder (spatial size, channels).
    conv1 = conv_block(inputs, 16)  # H x W x D, 16
    conv2 = conv_block(conv1, 32, strides=(2, 2, 1),
                       with_residual=True, with_se=True)  # H/2 x W/2 x D, 32
    conv3 = conv_block(conv2, 64, strides=(2, 2, 1),
                       with_residual=True, with_se=True)  # H/4 x W/4 x D, 64
    conv4 = conv_block(conv3, 128, strides=(2, 2, 2),
                       with_residual=True, with_se=True)  # H/8 x W/8 x D/2, 128
    conv5 = conv_block(conv4, 256, strides=(2, 2, 2),
                       with_residual=True, with_se=True)  # H/16 x W/16 x D/4, 256

    # First upsampling sequence: bring the bottleneck to every decoder level.
    up1_1 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 2),
                            padding='same')(conv5)  # H/8 x W/8 x D/2
    up1_2 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 2),
                            padding='same')(up1_1)  # H/4 x W/4 x D
    up1_3 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 1),
                            padding='same')(up1_2)  # H/2 x W/2 x D
    bridge1 = concatenate([conv4, up1_1])  # H/8 x W/8 x D/2, 128+128=256
    dec_conv_1 = conv_block(bridge1, 128, with_residual=True, with_se=True,
                            activation='leaky')  # H/8 x W/8 x D/2, 128

    # Second upsampling sequence.
    up2_1 = Conv3DTranspose(64, (3, 3, 3), strides=(2, 2, 2),
                            padding='same')(dec_conv_1)  # H/4 x W/4 x D
    up2_2 = Conv3DTranspose(64, (3, 3, 3), strides=(2, 2, 1),
                            padding='same')(up2_1)  # H/2 x W/2 x D
    bridge2 = concatenate([conv3, up1_2, up2_1])  # H/4 x W/4 x D, 64+128+64=256
    dec_conv_2 = conv_block(bridge2, 64, with_residual=True, with_se=True,
                            activation='leaky')  # H/4 x W/4 x D, 64

    # Third upsampling sequence.
    up3_1 = Conv3DTranspose(32, (3, 3, 3), strides=(2, 2, 1),
                            padding='same')(dec_conv_2)  # H/2 x W/2 x D
    bridge3 = concatenate([conv2, up1_3, up2_2, up3_1])  # H/2 x W/2 x D, 32+128+64+32=256
    dec_conv_3 = conv_block(bridge3, 32, with_residual=True, with_se=True,
                            activation='leaky')  # H/2 x W/2 x D, 32

    # Last upsampling back to input resolution.
    up4_1 = Conv3DTranspose(16, (3, 3, 3), strides=(2, 2, 1),
                            padding='same')(dec_conv_3)  # H x W x D
    dec_conv_4 = conv_block(up4_1, 16, with_residual=False, with_se=True,
                            activation='leaky')  # H x W x D, 16

    # Reduce to a single output channel with a 1x1x1 convolution.
    single_channel = Conv3D(1, (1, 1, 1))(dec_conv_4)

    # Sigmoid gives a per-voxel value in (0, 1).
    act = Activation('sigmoid')(single_channel)

    return Model(inputs=inputs, outputs=act)