PyTorch中的torch.gather函数详解

作者:起个名字好难2024.02.16 10:16浏览量:28

简介:torch.gather是一个在PyTorch中常用的函数,用于在指定维度上根据索引值从输入张量中提取数据。本文将详细介绍该函数的工作原理和用法,并通过示例代码演示其应用。

千帆应用开发平台“智能体Pro”全新上线 限时免费体验

面向慢思考场景,支持低代码配置的方式创建“智能体Pro”应用

立即体验

torch.gather是PyTorch中的一个函数,用于在指定维度上根据索引值从输入张量中提取数据。这个函数允许用户在张量的特定维度上选择性地收集数据,从而实现更灵活的数据处理和分析。

函数的基本语法如下:

  1. torch.gather(input, dim, index, *, sparse_grad=False, out=None)

参数说明:

  • input:输入张量,即要从中提取数据的原始张量。
  • dim:要索引的维,即选择数据的维度。
  • index:要收集的元素的索引,即从哪个位置提取数据。
  • sparse_grad:一个可选参数,默认为False。如果为True,关于input的梯度将是稀疏张量。
  • out:可选参数,用于指定输出张量的形状和类型。

工作原理:

torch.gather函数的工作原理是根据指定的维度和索引值,从输入张量中提取对应位置的数据,并组合成一个新的张量返回。具体来说,它会沿着指定的维度(dim)遍历索引张量(index),并从输入张量(input)中提取对应位置的数据。提取出的数据将按照索引张量的形状重新组合成一个新的张量。

值得注意的是,输入张量和索引张量之间不会进行广播操作。这意味着它们的形状必须完全匹配,或者输入张量的形状必须能够被索引张量的形状所容纳。

使用示例:

下面是一个使用torch.gather函数的示例代码,演示了如何从一个形状为[3, 4]的输入张量中提取数据:

  1. import torch
  2. # 创建一个形状为[3, 4]的输入张量
  3. input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
  4. # 创建一个形状为[3, 2]的索引张量
  5. index = torch.tensor([[0, 2], [1, 3], [2, 1]])
  6. # 使用torch.gather函数提取数据
  7. output = torch.gather(input, dim=1, index=index)
  8. # 打印输出结果
  9. print(output)

输出结果:

  1. tensor([[1, 3],
  2. [6, 8],
  3. [9, 10]])

在上面的示例中,我们创建了一个形状为[3, 4]的输入张量input和一个形状为[3, 2]的索引张量index。然后,我们使用torch.gather函数沿着维度1(即列方向)从input中提取对应位置的数据,得到一个新的形状为[3, 2]的输出张量output。最后,我们打印输出结果tensor([[1, 3], [6, 8], [9, 10]])。

article bottom image

相关文章推荐

发表评论