admin 管理员组

文章数量: 887021


2023年12月18日发(作者:mysql创建表单)

pytorch类别数不确定的损失函数

在机器学习任务中,损失函数是非常重要的一环。损失函数通常用于评估模型的性能以及指导模型的参数更新。在分类任务中,我们常常使用交叉熵损失函数进行训练。但是,如果我们的分类任务中的类别数不确定,怎么办呢?

先来回顾一下交叉熵损失函数。给定一个数据集 ${(x_i, y_i)}_{i=1}^n$,其中

$x_i$ 是输入数据,$y_i$ 是对应输入数据的真实标签。对于一次前向传播过程,我们计算出模型的预测值 $hat{y}$。交叉熵损失函数可以定义为:

$$L(y, hat{y}) = - sum_j y_j log(hat{y}_j)$$

其中 $y_j$ 是真实标签为 $j$ 的类别的概率值,$hat{y}_j$ 是模型预测标签为

$j$ 的类别的概率值。通常情况下,$y$ 是一个 one-hot 向量,即只有一个位置为 1,其他位置都是 0。这样,交叉熵损失函数可以简化为:

其中 $k$ 是真实标签的索引。

在不确定性量化框架中,我们引入一个额外的类别,表示所有不属于分类任务中明确定义的类别的数据点。这个类别可以称为“未知类别”(unknown class)或“噪声类别”(noise class)。因此,我们的分类任务可以被转化为一个多分类任务,其中包含了任意数量的明确定义的类别和一个未知类别。

对于明确定义的类别,我们可以使用交叉熵损失函数进行训练。对于未知类别,我们需要定义一个适当的损失函数。一种常见的定义是使用最大间距损失函数(maximum margin

loss function):

其中 $m$ 是一个超参数,表示不同类别之间的间隔距离。我们希望分类器将未知类别的概率最小化,而同时最大化不同类别之间的间隔距离。

然而,最大间距损失函数存在一个问题,即不用于训练时不利于梯度下降。因此,我们可以使用软最大化方法将最大间距损失函数转换为可微分的形式:

其中 $s$ 是一个超参数,表示 softmax 温度。当 $s$ 趋近于 0 时,软最大化方法将趋近于最大间距损失函数。

因此,不确定性量化框架中的总损失函数可以定义为:

其中 $L_{ce}$ 表示交叉熵损失函数,$lambda$ 是一个超参数,控制着最大间距损失函数的权重。

```python

import torch

import as nn

def forward(self, logits, targets):

ce_loss = ntropyLoss()(logits[:, :_classes],

targets)

# Calculate the maximum margin loss

batch_size = (0)

margin = ([batch_size, _classes, _classes])

*

margin[:, (_classes),

(_classes)] = 0.

mask = ([batch_size, _classes, _classes])

* (targets[:, None] != (_classes)[None, :])

margin = margin * (device=)

mm_loss = (exp(ature *

(logits[:, :_classes, None] - logits[:, None, :_classes] + margin),

dim=2))

return _ce * ce_loss + _mm * mm_loss

```

在计算最大间距损失函数时,我们使用了软最大化方法,将最大间距损失函数转换为可微分的形式。具体来说,我们使用了 PyTorch 内置的 `logsumexp` 函数来计算对数和,并使用超参数 `temperature` 来控制 softmax 温度。

最后,我们还需要注意到,我们将损失函数的计算限制在前 `num_classes` 个类别中,因为这些类别是我们明确定义的类别,而后面的类别则被视为未知类别。


本文标签: 损失 函数 类别 间距 定义