
1. 基本作用torch.gather的作用是从 input 的指定维度 dim 上按照 index 给出的索引位置取值。基本语法output torch.gather(input, dim, index)基本公式 三维举例 dim1 output[b][k][d] input[b][index[b][k][d]][d]其中input原始张量 dim指定在哪个维度上取值 index索引张量 output取出的结果2. 核心规则torch.gather有一个非常重要的规则output 的形状和 index 的形状相同。也就是说output.shape index.shapeindex不只是告诉从哪里取值它还决定了最终输出张量的形状。3. 二维例子假设import torch input torch.tensor([ [10, 20, 30], [40, 50, 60] ]) index torch.tensor([ [0, 2], [1, 0] ]) output torch.gather(input, dim1, indexindex)因为dim1所以是在列方向取值。取值过程output[0][0] input[0][index[0][0]] input[0][0] 10 output[0][1] input[0][index[0][1]] input[0][2] 30 output[1][0] input[1][index[1][0]] input[1][1] 50 output[1][1] input[1][index[1][1]] input[1][0] 40最终结果tensor([ [10, 30], [50, 40] ])4. 三维例子假设input.shape [B, N, D]含义是Bbatch size样本数量 N每个样本中的 patch 数量 D每个 patch 的特征维度如果index.shape [B, K, D]并且output torch.gather(input, dim1, indexindex)那么output.shape [B, K, D]因为dim1所以是在N这个维度上取值。核心公式是output[b][k][d] input[b][index[b][k][d]][d]解释B 维保持对应 D 维保持对应 只有 N 维根据 index[b][k][d] 指定的位置取值5. 结合 patch 选择代码理解常见代码_, indices torch.topk(attention_weights, k, dim1) selected_patches torch.gather( patches, 1, indices.unsqueeze(-1).expand(-1, -1, D) )假设patches.shape [B, N, D] attention_weights.shape [B, N] indices.shape [B, K]其中B样本数量 Npatch 数量 D每个 patch 的特征维度 K要选出的 patch 数量torch.topk得到的是每个样本中分数最高的K个 patch 索引indices.shape [B, K]但是patches是三维张量patches.shape [B, N, D]所以需要先扩展索引indices.unsqueeze(-1)形状变为[B, K, 1]再使用expand(-1, -1, D)形状变为[B, K, D]这样才能和patches的三维结构对应起来。6. 为什么要 expand 到 D 维因为每个 patch 不是一个数而是一个D维特征向量。如果某个 patch 的索引是indices[b][k] 3扩展后变成index[b][k] [3, 3, 3, ..., 3]长度是D。于是output[b][k][0] input[b][3][0] output[b][k][1] input[b][3][1] output[b][k][2] input[b][3][2] ... output[b][k][D-1] input[b][3][D-1]也就是把第3个 patch 的完整D维特征全部取出来。7. 最终效果对于代码selected_patches torch.gather( patches, 1, indices.unsqueeze(-1).expand(-1, -1, D) )它的作用是从每个样本的 N 个 patch 中 根据 top-k 得到的索引 选出 K 个重要 patch 并保留每个 patch 的完整 D 维特征。形状变化patches: [B, N, D] indices: [B, K] expanded index: [B, K, D] selected_patches: [B, K, D]8. 记忆方法可以这样记gather 按照 index从 input 的某个 dim 维度上取值。如果output torch.gather(input, dim1, indexindex)那么就是在第 1 维上取值 其他维度保持对应关系 output 的形状等于 index 的形状。对于三维张量input.shape [B, N, D] index.shape [B, K, D] dim 1核心公式output[b][k][d] input[b][index[b][k][d]][d]