Pytorchのgatherの挙動
インターンで実装を行っているときに,torchのgatherの挙動で少し困ったのでまとめてみた.
Pytorchのgatherの挙動
サンプルコード
まずサンプルコードとその出力結果を確認する
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
input = torch.tensor([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
])
indices = torch.tensor([
[1, 2, 0],
[0, 1, 2],
[2, 0, 1]
])
result1 = torch.gather(input=input, dim=0, index=indices)
result2 = torch.gather(input=input, dim=1, index=indices)
print(result1)
print(result2)
これに対する出力結果が
1
2
3
4
5
6
tensor([[4, 8, 3],
[1, 5, 9],
[7, 2, 6]])
tensor([[2, 3, 1],
[4, 5, 6],
[9, 7, 8]])
となる.
解説
gatherとは「集める」の意味で,torch.gatherは各次元にそってindexの値を集める.
dim=0のときは,dim=0にそって各indexの値をその位置に出力するようになっている.
おわりに
まとめたことによって理解できた気がする. 図にしてみるとなんとなくわかるが,また今度やるときには忘れてそう.
This post is licensed under CC BY 4.0 by the author.