-
torch.cat vs torch.stackAI\ML\DL/Pytorch 2023. 9. 29. 22:14๋ฐ์ํ
torch.cat ๊ณผ torch.stack ์ ์๋ก ๋ค๋ฅธ ํ ์ ๊ฒฐํฉ ๋ฐฉ๋ฒ์ด๋ค.
1) torch.cat (Concatenate)
torch.cat ์ ์๋ก ๋ค๋ฅธ ํ๋ ฌ์ด์๋ ๋ ํ ์๋ฅผ ํ๋์ ํ ์๋ก ๋ง๋ค์ด์ค๋ค. ๋ฐ๋ผ์ ์ฐจ์์ ๋ณํ๊ฐ ์๊ธฐ์ง ์๋๋ค.
์๋ฅผ ๋ค์ด ๋ ๊ฐ์ 2x3 ์ง๋ฆฌ ํ๋ ฌ์ cat ํด์ฃผ๋ฉด ์์์ฃผ๋ ๋ฐฉํฅ์ ๋ฐ๋ผ 4x3 ํ๋ ฌ์ด๋ 2x6 ํ๋ ฌ์ด ๋๋ค.
import torch # 2x3 ํ๋ ฌ ํ๋ a = torch.tensor([[1,2,3],[4,5,6]]) # 2x3 ํ๋ ฌ ํ๋ b = torch.tensor([[7,8,9],[10,11,12]]) # ํ๋ฐฉํฅ ์๊ธฐ c = torch.cat([a,b], dim=0) # ์ด๋ฐฉํฅ ์๊ธฐ d = torch.cat([a,b], dim=1) print(c) print(d)
# c ํ ์, 4x3 ํ๋ ฌ ํ๋ tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]]) # d ํ ์, 2x6 ํ๋ ฌ ํ๋ tensor([[ 1, 2, 3, 7, 8, 9], [ 4, 5, 6, 10, 11, 12]])
์ด๋ ๋ฏ ๋ ๊ฐ์ tensor ๋ฅผ catํ๋ฉด ์ฐจ์ ๋ณํ ์์ด ํ๋์ ํ ์๋ก ํฉ์ณ์ง๋ ๊ฒ์ ๋ณผ ์ ์๋ค.
2) torch.stack
torch.stack ํจ์๋ ์๋ก ๋ค๋ฅธ ํ ์๋ค์ ์๋ก์ด ์ฐจ์์ ์์ฑํ์ฌ ์ฐ๊ฒฐํด์ค๋ค.
๋ฐ๋ผ์ ๊ฒฐ๊ณผ ํ ์๋ ์ ๋ ฅ ํ ์๋ณด๋ค ์ฐจ์์ด ํ๋ ๋ ๋ง์ ๊ฒฝ์ฐ๊ฐ ๋ง๋ค. ์๋ฅผ ๋ค์ด 2x3 ํ๋ ฌ ๋๊ฐ๋ฅผ torch.stack ํ๋ฉด (2,2,3) ํ ์๊ฐ ๋๋ค.
torch.stack ์ ํ๋ฐฉํฅ๊ณผ ์ด๋ฐฉํฅ์ผ๋ก ์์์ ๋ ๊ฒฐ๊ณผ๊ฐ ๋ชจ๋ (2,2,3) ์ด์๋๋ฐ,
์ด์ฒ๋ผ ๊ธฐ์กด์ ํ๋ ฌ ์ฌ์ด์ฆ(2,3) ๋ ๋ณํํ์ง ์๊ณ , ์ฐจ์๋ง ํ๋ ๋ ๋์ด๋๊ฒ ๋๋ค.
import torch # 2x3 ํ๋ ฌ ํ๋ a = torch.tensor([[1,2,3],[4,5,6]]) # 2x3 ํ๋ ฌ ํ๋ b = torch.tensor([[7,8,9],[10,11,12]]) # ํ๋ฐฉํฅ ์๊ธฐ c = torch.stack([a,b], dim=0) # ์ด๋ฐฉํฅ ์๊ธฐ d = torch.stack([a,b], dim=1) print(c) print(d)
# c ํ ์, 2x3 ํ๋ ฌ 2๊ฐ tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]]) # d ํ ์, 2x3 ํ๋ ฌ 2๊ฐ tensor([[[ 1, 2, 3], [ 7, 8, 9]], [[ 4, 5, 6], [10, 11, 12]]])
์ฃผ์ํ ์
torch.stack ์ด๋ torch.cat ์ ์ฌ์ฉํ ๋๋ ํจ์์ ์ธ์๋ก ๋ฆฌ์คํธ๋ ํํ์ ์ ๋ฌํด์ผ ํ๋ค.
torch.stack(a,b)์ ๊ฐ์ด ๊ทธ๋ฅ ํ ์๋ก๋ง ๋ฃ์ด์ค๋ค๋ฉด ๋ค์๊ณผ ๊ฐ์ ์ค๋ฅ๊ฐ ๋ฐ์ํ ๊ฒ์ด๋ค.
TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor
๋ฐ๋ผ์ ๊ผญ ํํ์ด๋ ๋ฆฌ์คํธ๋ก ๋ฌถ์ ๋ค์ ์๋ง์ ํจ์๋ฅผ ์ฌ์ฉํด์ผ ํ๋ค.
'AI\ML\DL > Pytorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
๋ถ์ฐ ๋ฐ ๋ณ๋ ฌ ํ์ต (0) 2023.12.13 torch.cumprod() w.r.t. diffusion noise scheduling (1) 2023.12.03 CNN ๋ชจ๋ธ์ classifier๋จ์์ FC layer์ ์ ๋ ฅ ๋ ธ๋ ๊ฐ์ (0) 2023.08.03 PyTorch์ Numpy์์์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ ํฌ๋งท ์ฐจ์ด (0) 2023.08.01