from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv3D, Activation, Dense, concatenate, add
from tensorflow.keras.layers import Conv3DTranspose, LeakyReLU
from tensorflow.keras import regularizers
from tensorflow.keras.layers import GlobalAveragePooling3D, Reshape, multiply, Permute


def squeeze_excite_block(input, ratio=8):
    '''
    Create a channel-wise squeeze-excite block.

    Args:
        input: input tensor (5D: batch, three spatial dims, channels)
        ratio: reduction ratio of the bottleneck Dense layer
               (hidden width = channels // ratio)

    Returns:
        a keras tensor with the same shape as ``input``, rescaled
        channel-wise by learned gates in (0, 1)

    References
    -   [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
    '''
    init = input
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    filters = init.shape[channel_axis]
    se_shape = (1, 1, 1, filters)

    # Squeeze: pool each channel to a single scalar; excite: two-layer
    # bottleneck MLP emitting one sigmoid gate per channel.
    se = GlobalAveragePooling3D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(filters // ratio, activation='relu',
               kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid',
               kernel_initializer='he_normal', use_bias=False)(se)

    if K.image_data_format() == 'channels_first':
        # Move channels back to the front so the gates broadcast against init.
        se = Permute((4, 1, 2, 3))(se)

    x = multiply([init, se])
    return x


def build_dual_attention_unet(
    input_shape,
    l2_regularization=0.0001,
):
    '''
    Build a 3D U-Net with squeeze-excite (channel attention) blocks and
    multi-scale decoder skip connections.

    Args:
        input_shape: shape of the input volume, excluding the batch axis.
            The spatial-size comments below assume 144x144x18 input
            (NOTE(review): inferred from the comments in the original
            code — confirm against the caller).
        l2_regularization: L2 weight-decay factor applied to every
            convolution kernel.

    Returns:
        an uncompiled keras Model mapping the input volume to a
        single-channel sigmoid (per-voxel probability) map of the same
        spatial size.
    '''

    def conv_layer(x, kernel_size, out_filters, strides=(1, 1, 1)):
        # L2-regularized, bias-free 3D convolution; he_normal init pairs
        # with the ReLU activations used throughout.
        x = Conv3D(out_filters, kernel_size,
                   strides=strides,
                   padding='same',
                   kernel_regularizer=regularizers.l2(l2_regularization),
                   kernel_initializer='he_normal',
                   use_bias=False)(x)
        return x

    def conv_block(input, out_filters, strides=(1, 1, 1),
                   with_residual=False, with_se=False, activation='relu'):
        # Strided convolution to convsample (downsample via strides)
        x = conv_layer(input, (3, 3, 3), out_filters, strides)
        x = Activation('relu')(x)

        # Unstrided convolution
        x = conv_layer(x, (3, 3, 3), out_filters)

        # Add a squeeze-excite block (added to, not replacing, the features)
        if with_se:
            se = squeeze_excite_block(x)
            x = add([x, se])

        # Add a residual connection using a 1x1x1 convolution with strides
        # so the shortcut matches the main path's shape and filter count.
        if with_residual:
            residual = conv_layer(input, (1, 1, 1), out_filters, strides)
            x = add([x, residual])

        # Activate whatever comes out of this
        if activation == 'leaky':
            x = LeakyReLU(alpha=.1)(x)
        else:
            x = Activation('relu')(x)

        return x

    # If we already have only one input, no need to combine anything
    inputs = Input(input_shape)

    # Downsampling path (spatial sizes assume 144x144x18 input)
    conv1 = conv_block(inputs, 16)
    conv2 = conv_block(conv1, 32, strides=(2, 2, 1), with_residual=True, with_se=True)   # 72x72x18
    conv3 = conv_block(conv2, 64, strides=(2, 2, 1), with_residual=True, with_se=True)   # 36x36x18
    conv4 = conv_block(conv3, 128, strides=(2, 2, 2), with_residual=True, with_se=True)  # 18x18x9
    conv5 = conv_block(conv4, 256, strides=(2, 2, 2), with_residual=True, with_se=True)  # 9x9x9

    # First upsampling sequence: chain of transposes from the bottleneck,
    # each stage also feeding a later bridge (multi-scale skip).
    up1_1 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 2), padding='same')(conv5)   # 18x18x9
    up1_2 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 2), padding='same')(up1_1)   # 36x36x18
    up1_3 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 1), padding='same')(up1_2)   # 72x72x18
    bridge1 = concatenate([conv4, up1_1])  # 18x18x9 (128+128=256)
    dec_conv_1 = conv_block(bridge1, 128, with_residual=True, with_se=True,
                            activation='leaky')  # 18x18x9

    # Second upsampling sequence
    up2_1 = Conv3DTranspose(64, (3, 3, 3), strides=(2, 2, 2), padding='same')(dec_conv_1)  # 36x36x18
    up2_2 = Conv3DTranspose(64, (3, 3, 3), strides=(2, 2, 1), padding='same')(up2_1)       # 72x72x18
    bridge2 = concatenate([conv3, up1_2, up2_1])  # 36x36x18 (64+128+64=256)
    dec_conv_2 = conv_block(bridge2, 64, with_residual=True, with_se=True,
                            activation='leaky')

    # Final upsampling sequence
    up3_1 = Conv3DTranspose(32, (3, 3, 3), strides=(2, 2, 1), padding='same')(dec_conv_2)  # 72x72x18
    bridge3 = concatenate([conv2, up1_3, up2_2, up3_1])  # 72x72x18 (32+128+64+32=256)
    dec_conv_3 = conv_block(bridge3, 32, with_residual=True, with_se=True,
                            activation='leaky')

    # Last upsampling to make heatmap: stride (2,2,1) on a 72x72x18 tensor
    # restores the full 144x144x18 resolution.
    up4_1 = Conv3DTranspose(16, (3, 3, 3), strides=(2, 2, 1), padding='same')(dec_conv_3)  # 144x144x18
    dec_conv_4 = conv_block(up4_1, 16, with_residual=False, with_se=True,
                            activation='leaky')  # 144x144x18 (16)

    # Reduce to a single output channel with a 1x1x1 convolution
    single_channel = Conv3D(1, (1, 1, 1))(dec_conv_4)

    # Apply sigmoid activation to get binary prediction per voxel
    act = Activation('sigmoid')(single_channel)

    # Model definition
    model = Model(inputs=inputs, outputs=act)
    return model