78 lines
2.5 KiB
Python
Executable File
78 lines
2.5 KiB
Python
Executable File
import tensorflow as tf
|
|
from tensorflow.image import ssim
|
|
from tensorflow.keras import backend as K
|
|
|
|
def weighted_binary_cross_entropy(weights: dict, from_logits: bool = False):
|
|
'''
|
|
Return a function for calculating weighted binary cross entropy
|
|
It should be used for multi-hot encoded labels
|
|
|
|
# Example
|
|
y_true = tf.convert_to_tensor([1, 0, 0, 0, 0, 0], dtype=tf.int64)
|
|
y_pred = tf.convert_to_tensor([0.6, 0.1, 0.1, 0.9, 0.1, 0.], dtype=tf.float32)
|
|
weights = {
|
|
0: 1.,
|
|
1: 2.
|
|
}
|
|
# with weights
|
|
loss_fn = get_loss_for_multilabels(weights=weights, from_logits=False)
|
|
loss = loss_fn(y_true, y_pred)
|
|
print(loss)
|
|
# tf.Tensor(0.6067193, shape=(), dtype=float32)
|
|
|
|
# without weights
|
|
loss_fn = get_loss_for_multilabels()
|
|
loss = loss_fn(y_true, y_pred)
|
|
print(loss)
|
|
# tf.Tensor(0.52158177, shape=(), dtype=float32)
|
|
|
|
# Another example
|
|
y_true = tf.convert_to_tensor([[0., 1.], [0., 0.]], dtype=tf.float32)
|
|
y_pred = tf.convert_to_tensor([[0.6, 0.4], [0.4, 0.6]], dtype=tf.float32)
|
|
weights = {
|
|
0: 1.,
|
|
1: 2.
|
|
}
|
|
# with weights
|
|
loss_fn = get_loss_for_multilabels(weights=weights, from_logits=False)
|
|
loss = loss_fn(y_true, y_pred)
|
|
print(loss)
|
|
# tf.Tensor(1.0439969, shape=(), dtype=float32)
|
|
|
|
# without weights
|
|
loss_fn = get_loss_for_multilabels()
|
|
loss = loss_fn(y_true, y_pred)
|
|
print(loss)
|
|
# tf.Tensor(0.81492424, shape=(), dtype=float32)
|
|
|
|
@param weights A dict setting weights for 0 and 1 label. e.g.
|
|
{
|
|
0: 1.
|
|
1: 8.
|
|
}
|
|
For this case, we want to emphasise those true (1) label,
|
|
because we have many false (0) label. e.g.
|
|
[
|
|
[0 1 0 0 0 0 0 0 0 1]
|
|
[0 0 0 0 1 0 0 0 0 0]
|
|
[0 0 0 0 1 0 0 0 0 0]
|
|
]
|
|
|
|
|
|
|
|
@param from_logits If False, we apply sigmoid to each logit
|
|
@return A function to calcualte (weighted) binary cross entropy
|
|
'''
|
|
assert 0 in weights
|
|
assert 1 in weights
|
|
|
|
def weighted_cross_entropy_fn(y_true, y_pred):
|
|
tf_y_true = tf.cast(y_true, dtype=y_pred.dtype)
|
|
tf_y_pred = tf.cast(y_pred, dtype=y_pred.dtype)
|
|
|
|
weights_v = tf.where(tf.equal(tf_y_true, 1), weights[1], weights[0])
|
|
ce = K.binary_crossentropy(tf_y_true, tf_y_pred, from_logits=from_logits)
|
|
loss = K.mean(tf.multiply(ce, weights_v))
|
|
return loss
|
|
|
|
return weighted_cross_entropy_fn |