PyTorch中的torch.gather函数详解
2024.02.16 10:16浏览量:28简介:torch.gather是一个在PyTorch中常用的函数,用于在指定维度上根据索引值从输入张量中提取数据。本文将详细介绍该函数的工作原理和用法,并通过示例代码演示其应用。
千帆应用开发平台“智能体Pro”全新上线 限时免费体验
面向慢思考场景,支持低代码配置的方式创建“智能体Pro”应用
torch.gather是PyTorch中的一个函数,用于在指定维度上根据索引值从输入张量中提取数据。这个函数允许用户在张量的特定维度上选择性地收集数据,从而实现更灵活的数据处理和分析。
函数的基本语法如下:
torch.gather(input, dim, index, *, sparse_grad=False, out=None)
参数说明:
input
:输入张量,即要从中提取数据的原始张量。dim
:要索引的维,即选择数据的维度。index
:要收集的元素的索引,即从哪个位置提取数据。sparse_grad
:一个可选参数,默认为False。如果为True,关于input的梯度将是稀疏张量。out
:可选参数,用于指定输出张量的形状和类型。
工作原理:
torch.gather函数的工作原理是根据指定的维度和索引值,从输入张量中提取对应位置的数据,并组合成一个新的张量返回。具体来说,它会沿着指定的维度(dim)遍历索引张量(index),并从输入张量(input)中提取对应位置的数据。提取出的数据将按照索引张量的形状重新组合成一个新的张量。
值得注意的是,输入张量和索引张量之间不会进行广播操作。这意味着它们的形状必须完全匹配,或者输入张量的形状必须能够被索引张量的形状所容纳。
使用示例:
下面是一个使用torch.gather函数的示例代码,演示了如何从一个形状为[3, 4]的输入张量中提取数据:
import torch
# 创建一个形状为[3, 4]的输入张量
input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# 创建一个形状为[3, 2]的索引张量
index = torch.tensor([[0, 2], [1, 3], [2, 1]])
# 使用torch.gather函数提取数据
output = torch.gather(input, dim=1, index=index)
# 打印输出结果
print(output)
输出结果:
tensor([[1, 3],
[6, 8],
[9, 10]])
在上面的示例中,我们创建了一个形状为[3, 4]的输入张量input和一个形状为[3, 2]的索引张量index。然后,我们使用torch.gather函数沿着维度1(即列方向)从input中提取对应位置的数据,得到一个新的形状为[3, 2]的输出张量output。最后,我们打印输出结果tensor([[1, 3], [6, 8], [9, 10]])。

发表评论
登录后可评论,请前往 登录 或 注册