ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • torch.cat vs torch.stack
    AI\ML\DL/Pytorch 2023. 9. 29. 22:14
    ๋ฐ˜์‘ํ˜•

    torch.cat ๊ณผ torch.stack ์€ ์„œ๋กœ ๋‹ค๋ฅธ ํ…์„œ ๊ฒฐํ•ฉ ๋ฐฉ๋ฒ•์ด๋‹ค. 

     

    1) torch.cat (Concatenate)

    torch.cat ์€ ์„œ๋กœ ๋‹ค๋ฅธ ํ–‰๋ ฌ์ด์—ˆ๋˜ ๋‘ ํ…์„œ๋ฅผ ํ•˜๋‚˜์˜ ํ…์„œ๋กœ ๋งŒ๋“ค์–ด์ค€๋‹ค. ๋”ฐ๋ผ์„œ ์ฐจ์›์˜ ๋ณ€ํ™”๊ฐ€ ์ƒ๊ธฐ์ง€ ์•Š๋Š”๋‹ค. 

    ์˜ˆ๋ฅผ ๋“ค์–ด ๋‘ ๊ฐœ์˜ 2x3 ์งœ๋ฆฌ ํ–‰๋ ฌ์„ cat ํ•ด์ฃผ๋ฉด ์Œ“์•„์ฃผ๋Š” ๋ฐฉํ–ฅ์— ๋”ฐ๋ผ 4x3 ํ–‰๋ ฌ์ด๋‚˜ 2x6 ํ–‰๋ ฌ์ด ๋œ๋‹ค. 

     

    import torch
    
    # 2x3 ํ–‰๋ ฌ ํ•˜๋‚˜
    a = torch.tensor([[1,2,3],[4,5,6]])
    # 2x3 ํ–‰๋ ฌ ํ•˜๋‚˜
    b = torch.tensor([[7,8,9],[10,11,12]])
    
    # ํ–‰๋ฐฉํ–ฅ ์Œ“๊ธฐ
    c = torch.cat([a,b], dim=0)
    
    # ์—ด๋ฐฉํ–ฅ ์Œ“๊ธฐ
    d = torch.cat([a,b], dim=1)
    
    print(c)
    print(d)
    # c ํ…์„œ, 4x3 ํ–‰๋ ฌ ํ•˜๋‚˜
    tensor([[ 1,  2,  3],
            [ 4,  5,  6],
            [ 7,  8,  9],
            [10, 11, 12]])
    
    # d ํ…์„œ, 2x6 ํ–‰๋ ฌ ํ•˜๋‚˜
    tensor([[ 1,  2,  3,  7,  8,  9],
            [ 4,  5,  6, 10, 11, 12]])

    ์ด๋ ‡๋“ฏ ๋‘ ๊ฐœ์˜ tensor ๋ฅผ catํ•˜๋ฉด ์ฐจ์› ๋ณ€ํ™” ์—†์ด ํ•˜๋‚˜์˜ ํ…์„œ๋กœ ํ•ฉ์ณ์ง€๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. 

     

    2) torch.stack

    torch.stack ํ•จ์ˆ˜๋Š” ์„œ๋กœ ๋‹ค๋ฅธ ํ…์„œ๋“ค์„ ์ƒˆ๋กœ์šด ์ฐจ์›์„ ์ƒ์„ฑํ•˜์—ฌ ์—ฐ๊ฒฐํ•ด์ค€๋‹ค.

    ๋”ฐ๋ผ์„œ ๊ฒฐ๊ณผ ํ…์„œ๋Š” ์ž…๋ ฅ ํ…์„œ๋ณด๋‹ค ์ฐจ์›์ด ํ•˜๋‚˜ ๋” ๋งŽ์€ ๊ฒฝ์šฐ๊ฐ€ ๋งŽ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด 2x3 ํ–‰๋ ฌ ๋‘๊ฐœ๋ฅผ torch.stack ํ•˜๋ฉด (2,2,3) ํ…์„œ๊ฐ€ ๋œ๋‹ค.

    torch.stack ์€ ํ–‰๋ฐฉํ–ฅ๊ณผ ์—ด๋ฐฉํ–ฅ์œผ๋กœ ์Œ“์•˜์„ ๋•Œ ๊ฒฐ๊ณผ๊ฐ€ ๋ชจ๋‘ (2,2,3) ์ด์—ˆ๋Š”๋ฐ,

    ์ด์ฒ˜๋Ÿผ ๊ธฐ์กด์˜ ํ–‰๋ ฌ ์‚ฌ์ด์ฆˆ(2,3) ๋Š” ๋ณ€ํ™”ํ•˜์ง€ ์•Š๊ณ , ์ฐจ์›๋งŒ ํ•˜๋‚˜ ๋” ๋Š˜์–ด๋‚˜๊ฒŒ ๋œ๋‹ค.

     

    import torch
    
    # 2x3 ํ–‰๋ ฌ ํ•˜๋‚˜
    a = torch.tensor([[1,2,3],[4,5,6]])
    # 2x3 ํ–‰๋ ฌ ํ•˜๋‚˜
    b = torch.tensor([[7,8,9],[10,11,12]])
    
    # ํ–‰๋ฐฉํ–ฅ ์Œ“๊ธฐ
    c = torch.stack([a,b], dim=0)
    # ์—ด๋ฐฉํ–ฅ ์Œ“๊ธฐ
    d = torch.stack([a,b], dim=1)
    
    print(c)
    print(d)
    # c ํ…์„œ, 2x3 ํ–‰๋ ฌ 2๊ฐœ
    tensor([[[ 1,  2,  3],
             [ 4,  5,  6]],
    
            [[ 7,  8,  9],
             [10, 11, 12]]])
             
    # d ํ…์„œ, 2x3 ํ–‰๋ ฌ 2๊ฐœ
    tensor([[[ 1,  2,  3],
             [ 7,  8,  9]],
    
            [[ 4,  5,  6],
             [10, 11, 12]]])

    ์ฃผ์˜ํ• ์ 

    torch.stack ์ด๋‚˜ torch.cat ์„ ์‚ฌ์šฉํ•  ๋•Œ๋Š” ํ•จ์ˆ˜์˜ ์ธ์ž๋กœ ๋ฆฌ์ŠคํŠธ๋‚˜ ํŠœํ”Œ์„ ์ „๋‹ฌํ•ด์•ผ ํ•œ๋‹ค.

    torch.stack(a,b)์™€ ๊ฐ™์ด ๊ทธ๋ƒฅ ํ…์„œ๋กœ๋งŒ ๋„ฃ์–ด์ค€๋‹ค๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•  ๊ฒƒ์ด๋‹ค.

    TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

     

    ๋”ฐ๋ผ์„œ ๊ผญ ํŠœํ”Œ์ด๋‚˜ ๋ฆฌ์ŠคํŠธ๋กœ ๋ฌถ์€ ๋‹ค์Œ ์•Œ๋งž์€ ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•œ๋‹ค. 

Designed by Tistory.