深入理解PyTorch的交叉熵损失函数:参数、说明与实践

作者:狼烟四起2023.12.25 07:15浏览量:4

简介:PyTorch中的交叉熵损失函数(CrossEntropy)是一种常用的损失函数,用于多分类问题。这个函数计算的是真实标签与预测标签之间的交叉熵距离。以下是关于PyTorch的交叉熵损失函数的输入参数以及相关说明:

千帆应用开发平台“智能体Pro”全新上线 限时免费体验

面向慢思考场景,支持低代码配置的方式创建“智能体Pro”应用

立即体验

PyTorch中的交叉熵损失函数(CrossEntropy)是一种常用的损失函数,用于多分类问题。这个函数计算的是真实标签与预测标签之间的交叉熵距离。以下是关于PyTorch的交叉熵损失函数的输入参数以及相关说明:

  1. 输入参数
  • input (Tensor): 模型的输出,是一个未归一化的对数概率张量,形状为(batch_size, C, …),其中C是类别数,…表示其他维度。
  • target (LongTensor or Tensor): 目标标签,是一个形状为(batch_size, …)的一维张量,其中每个元素的值在0到C-1之间。如果是LongTensor,它的类型必须是torch.int64。
  • weight (Tensor, optional): 一个可选的权重张量,与inputtarget具有相同的形状。默认值为None。
  • reduction (str, optional): 指定如何对每个类别的损失进行归约。可选的值有’none’、’mean’和’sum’。默认值为’mean’。
  1. 函数说明
  • input:这是一个未归一化的对数概率张量,表示模型对于每个类别的预测概率。它的值在0到1之间(包括0和1),并且经过softmax函数处理。在多分类问题中,通常使用模型的最后一层的输出作为input
  • target:这是真实标签的张量,每个元素表示一个样本所属的类别。它的值应该在0到C-1之间,其中C是类别数。注意,target可以是任意形状的一维张量,只要和input的batch_size维度相匹配即可。
  • weight:这是一个可选参数,表示每个类别的权重。它的默认值为None,表示所有类别的权重相等。如果指定了权重,则每个类别的损失将根据其权重进行计算。这在处理类别不平衡问题时非常有用。
  • reduction:这个参数决定了如何对每个类别的损失进行归约。如果设置为’none’,则返回每个类别的损失;如果设置为’mean’,则将所有类别的损失相加然后平均;如果设置为’sum’,则将所有类别的损失相加。在大多数情况下,使用默认值’mean’就足够了。
  1. 返回值
  • 如果reduction='none',则返回一个形状为(batch_size, C, …)的张量,其中包含每个类别的损失;
  • 如果reduction='mean'reduction='sum',则返回一个标量,表示所有类别的损失的总和。
  1. 示例代码
    1. import torch
    2. import torch.nn as nn
    3. # 假设我们有一个3分类问题
    4. criterion = nn.CrossEntropyLoss()
    5. input = torch.randn(3, 3, requires_grad=True) # 随机生成一个形状为(3, 3)的张量作为模型输出
    6. target = torch.empty(3, dtype=torch.long).random_(3) # 随机生成3个类别标签
    7. output = criterion(input, target) # 计算交叉熵损失
    8. output.backward() # 反向传播计算梯度
    在上面的示例中,我们首先导入了PyTorch库和交叉熵损失函数。然后我们随机生成了一个形状为(3, 3)的张量作为模型的输出,并随机生成了3个类别标签作为真实标签。最后我们计算了交叉熵损失并进行了反向传播计算梯度。
article bottom image

相关文章推荐

发表评论

图片