深入理解PyTorch的交叉熵损失函数:参数、说明与实践
2023.12.25 07:15浏览量:4简介:PyTorch中的交叉熵损失函数(CrossEntropy)是一种常用的损失函数,用于多分类问题。这个函数计算的是真实标签与预测标签之间的交叉熵距离。以下是关于PyTorch的交叉熵损失函数的输入参数以及相关说明:
千帆应用开发平台“智能体Pro”全新上线 限时免费体验
面向慢思考场景,支持低代码配置的方式创建“智能体Pro”应用
立即体验
PyTorch中的交叉熵损失函数(CrossEntropy)是一种常用的损失函数,用于多分类问题。这个函数计算的是真实标签与预测标签之间的交叉熵距离。以下是关于PyTorch的交叉熵损失函数的输入参数以及相关说明:
- 输入参数:
input
(Tensor): 模型的输出,是一个未归一化的对数概率张量,形状为(batch_size, C, …),其中C是类别数,…表示其他维度。target
(LongTensor or Tensor): 目标标签,是一个形状为(batch_size, …)的一维张量,其中每个元素的值在0到C-1之间。如果是LongTensor,它的类型必须是torch.int64。weight
(Tensor, optional): 一个可选的权重张量,与input
和target
具有相同的形状。默认值为None。reduction
(str, optional): 指定如何对每个类别的损失进行归约。可选的值有’none’、’mean’和’sum’。默认值为’mean’。
- 函数说明:
- input:这是一个未归一化的对数概率张量,表示模型对于每个类别的预测概率。它的值在0到1之间(包括0和1),并且经过softmax函数处理。在多分类问题中,通常使用模型的最后一层的输出作为
input
。 - target:这是真实标签的张量,每个元素表示一个样本所属的类别。它的值应该在0到C-1之间,其中C是类别数。注意,
target
可以是任意形状的一维张量,只要和input
的batch_size维度相匹配即可。 - weight:这是一个可选参数,表示每个类别的权重。它的默认值为None,表示所有类别的权重相等。如果指定了权重,则每个类别的损失将根据其权重进行计算。这在处理类别不平衡问题时非常有用。
- reduction:这个参数决定了如何对每个类别的损失进行归约。如果设置为’none’,则返回每个类别的损失;如果设置为’mean’,则将所有类别的损失相加然后平均;如果设置为’sum’,则将所有类别的损失相加。在大多数情况下,使用默认值’mean’就足够了。
- 返回值:
- 如果
reduction='none'
,则返回一个形状为(batch_size, C, …)的张量,其中包含每个类别的损失; - 如果
reduction='mean'
或reduction='sum'
,则返回一个标量,表示所有类别的损失的总和。
- 示例代码:
在上面的示例中,我们首先导入了PyTorch库和交叉熵损失函数。然后我们随机生成了一个形状为(3, 3)的张量作为模型的输出,并随机生成了3个类别标签作为真实标签。最后我们计算了交叉熵损失并进行了反向传播计算梯度。import torch
import torch.nn as nn
# 假设我们有一个3分类问题
criterion = nn.CrossEntropyLoss()
input = torch.randn(3, 3, requires_grad=True) # 随机生成一个形状为(3, 3)的张量作为模型输出
target = torch.empty(3, dtype=torch.long).random_(3) # 随机生成3个类别标签
output = criterion(input, target) # 计算交叉熵损失
output.backward() # 反向传播计算梯度

发表评论
登录后可评论,请前往 登录 或 注册