在 keras 中创建自定义损失函数

2025-02-08 08:52:00
admin
原创
46
摘要:问题描述:嗨,我一直在尝试在 keras 中为 dice_error_coefficient 创建自定义损失函数。它在tensorboard中有实现,我尝试在 keras 中使用 tensorflow 中的相同函数,但当我使用model.train_on_batch或model.fit时,它一直返回NoneT...

问题描述:

嗨,我一直在尝试在 keras 中为 dice_error_coefficient 创建自定义损失函数。它在tensorboard中有实现,我尝试在 keras 中使用 tensorflow 中的相同函数,但当我使用model.train_on_batchmodel.fit时,它一直返回NoneType,而当在模型的指标中使用时,它会给出正确的值。有人能帮我吗?我试过使用 ahundt 的 Keras-FCN 等库,他使用了自定义损失函数,但似乎都不起作用。代码中的目标和输出分别是 y_true 和 y_pred,如 keras 中losses.py文件中使用的。

def dice_hard_coe(target, output, threshold=0.5, axis=[1,2], smooth=1e-5):
    """References
    -----------
    - `Wiki-Dice <https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient>`_
    """

    output = tf.cast(output > threshold, dtype=tf.float32)
    target = tf.cast(target > threshold, dtype=tf.float32)
    inse = tf.reduce_sum(tf.multiply(output, target), axis=axis)
    l = tf.reduce_sum(output, axis=axis)
    r = tf.reduce_sum(target, axis=axis)
    hard_dice = (2. * inse + smooth) / (l + r + smooth)
    hard_dice = tf.reduce_mean(hard_dice)
    return hard_dice

解决方案 1:

在 Keras 中实现参数化自定义损失函数有两个步骤。首先,编写一个系数/度量方法。其次,编写一个包装函数,按照 Keras 需要的方式格式化内容。

  1. 对于 DICE 等简单的自定义损失函数,直接使用 Keras 后端而不是 tensorflow 实际上要简洁得多。以下是通过这种方式实现的系数示例:

import keras.backend as K
def dice_coef(y_true, y_pred, smooth, thresh):
    y_pred = y_pred > thresh
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)

    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
  1. 现在到了棘手的部分。Keras 损失函数只能将 (y_true, y_pred) 作为参数。因此,我们需要一个返回另一个函数的单独函数。

def dice_loss(smooth, thresh):
  def dice(y_true, y_pred)
    return -dice_coef(y_true, y_pred, smooth, thresh)
  return dice

最后,您可以在Keras编译中按如下方式使用它。

# build model 
model = my_model()
# get the loss function
model_dice = dice_loss(smooth=1e-5, thresh=0.5)
# compile model
model.compile(loss=model_dice)

解决方案 2:

根据文档,您可以使用如下自定义损失函数:

loss_fn(y_true, y_pred)任何具有返回损失数组(输入批次中的样本之一)的签名的可调用函数都可以作为损失传递给 compile()。请注意,对于任何此类损失,都会自动支持样本加权。

举一个简单的例子:

def my_loss_fn(y_true, y_pred):
    squared_difference = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)  # Note the `axis=-1`

model.compile(optimizer='adam', loss=my_loss_fn)

完整示例:

import tensorflow as tf
import numpy as np

def my_loss_fn(y_true, y_pred):
    squared_difference = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)  # Note the `axis=-1`

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1)])

model.compile(optimizer='adam', loss=my_loss_fn)

x = np.random.rand(1000)
y = x**2

history = model.fit(x, y, epochs=10)

解决方案 3:

此外,你可以通过继承来扩展现有的损失函数。例如,屏蔽BinaryCrossEntropy

class MaskedBinaryCrossentropy(tf.keras.losses.BinaryCrossentropy):
    def call(self, y_true, y_pred):
        mask = y_true != -1
        y_true = y_true[mask]
        y_pred = y_pred[mask]
        return super().call(y_true, y_pred)

一个很好的起点是custom log指南:https ://www.tensorflow.org/guide/keras/train_and_evaluate#custom_losses

相关推荐
  政府信创国产化的10大政策解读一、信创国产化的背景与意义信创国产化,即信息技术应用创新国产化,是当前中国信息技术领域的一个重要发展方向。其核心在于通过自主研发和创新,实现信息技术应用的自主可控,减少对外部技术的依赖,并规避潜在的技术制裁和风险。随着全球信息技术竞争的加剧,以及某些国家对中国在科技领域的打压,信创国产化显...
工程项目管理   1565  
  为什么项目管理通常仍然耗时且低效?您是否还在反复更新电子表格、淹没在便利贴中并参加每周更新会议?这确实是耗费时间和精力。借助软件工具的帮助,您可以一目了然地全面了解您的项目。如今,国内外有足够多优秀的项目管理软件可以帮助您掌控每个项目。什么是项目管理软件?项目管理软件是广泛行业用于项目规划、资源分配和调度的软件。它使项...
项目管理软件   1354  
  信创国产芯片作为信息技术创新的核心领域,对于推动国家自主可控生态建设具有至关重要的意义。在全球科技竞争日益激烈的背景下,实现信息技术的自主可控,摆脱对国外技术的依赖,已成为保障国家信息安全和产业可持续发展的关键。国产芯片作为信创产业的基石,其发展水平直接影响着整个信创生态的构建与完善。通过不断提升国产芯片的技术实力、产...
国产信创系统   21  
  信创生态建设旨在实现信息技术领域的自主创新和安全可控,涵盖了从硬件到软件的全产业链。随着数字化转型的加速,信创生态建设的重要性日益凸显,它不仅关乎国家的信息安全,更是推动产业升级和经济高质量发展的关键力量。然而,在推进信创生态建设的过程中,面临着诸多复杂且严峻的挑战,需要深入剖析并寻找切实可行的解决方案。技术创新难题技...
信创操作系统   27  
  信创产业作为国家信息技术创新发展的重要领域,对于保障国家信息安全、推动产业升级具有关键意义。而国产芯片作为信创产业的核心基石,其研发进展备受关注。在信创国产芯片的研发征程中,面临着诸多复杂且艰巨的难点,这些难点犹如一道道关卡,阻碍着国产芯片的快速发展。然而,科研人员和相关企业并未退缩,积极探索并提出了一系列切实可行的解...
国产化替代产品目录   28  
热门文章
项目管理软件有哪些?
云禅道AD
禅道项目管理软件

云端的项目管理软件

尊享禅道项目软件收费版功能

无需维护,随时随地协同办公

内置subversion和git源码管理

每天备份,随时转为私有部署

免费试用