ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 5๏ธโƒฃ Pytorch ๊ธฐ์ดˆ_autograd
    AI\ML\DL/Pytorch 2023. 7. 14. 17:08
    ๋ฐ˜์‘ํ˜•

    1) requires_grad

    ํ…์„œ ๊ฐ์ฒด์˜ attribute ๋กœ, ๊ทธ๋ž˜๋””์–ธํŠธ์˜ ๊ณ„์‚ฐ ์—ฌ๋ถ€๋ฅผ ๊ฒฐ์ •ํ•œ๋‹ค. 

    requires_grad ์†์„ฑ์„ True ๋กœ ์„ค์ •ํ•˜์—ฌ ํ…์„œ์— ๋Œ€ํ•œ ๊ทธ๋ž˜๋””์–ธํŠธ ๊ณ„์‚ฐ์„ ํ™œ์„ฑํ™”ํ•  ์ˆ˜ ์žˆ๋‹ค. 

    requres_grad ์†์„ฑ์ด False ๋กœ ์„ค์ •๋œ ํ…์„œ์— ๋Œ€ํ•ด์„œ๋Š” ๊ทธ๋ž˜๋””์–ธํŠธ๋ฅผ ๊ณ„์‚ฐํ•˜์ง€ ์•Š๋Š”๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ๋ถˆํ•„์š”ํ•œ ๊ณ„์‚ฐ์„ ์ค„์—ฌ ์—ฐ์‚ฐ ์†๋„๋ฅผ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ๋‹ค. 

     

    2) ์—ญ์ „ํŒŒ ๊ณ„์‚ฐ

    x=torch.tensor([1.], requires_grad=True)
    y=x**2
    y.retain_grad()
    z=3*y
    z.backward()
    print(x.grad)
    print(y.grad)

    x -> x**2 -> 3x**2 ์ˆœ์œผ๋กœ ์ง„ํ–‰๋˜๋ฏ€๋กœ

    x ์— ๋Œ€ํ•ด์„œ 3x**2 ๋ฅผ chain rule์„ ํ†ตํ•ด ํŽธ๋ฏธ๋ถ„ํ•˜๋ฉด ๊ฒฐ๊ณผ๋Š” 6์ด ๋œ๋‹ค. 

    ์ถœ๋ ฅ์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค. 

    tensor([6.])
    tensor([3.])

    ์ค‘๊ฐ„์— ์žˆ๋Š” y ์˜ ๋ฏธ๋ถ„ ๊ณ„์‚ฐ์€ ์›๋ž˜ ๋˜์ง€ ์•Š๋Š”๋ฐ, y.retain_grad ๋กœ ๋ฏธ๋ถ„๊ฐ’์„ ์œ ์ง€์‹œ์ผœ ์ฃผ์—ˆ์œผ๋ฏ€๋กœ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. 

     

    3) detach ์™€ torch.no_grad

     

    detach() ๋ฉ”์„œ๋“œ๋Š” ํ…์„œ๋ฅผ ๊ทธ๋ž˜ํ”„ ๊ณ„์‚ฐ์—์„œ ๋ถ„๋ฆฌํ•œ๋‹ค. ์ฆ‰, ๊ทธ๋ž˜๋””์–ธํŠธ ๊ณ„์‚ฐ์„ ๋น„ํ™œ์„ฑํ™”ํ•˜๊ณ  ํ•ด๋‹น ํ…์„œ๋ฅผ ์ƒˆ๋กœ์šด ํ…์„œ๋กœ ๋งŒ๋“ ๋‹ค. ์ฒ˜์Œ์— requires_grad ๋ฅผ True๋กœ ์„ค์ •ํ–ˆ๋”๋ผ๋„ detach๋ฅผ ํ•ด์ฃผ๋ฉด ์—ญ์ „ํŒŒ๋ฅผ ์ˆ˜ํ–‰ํ•  ๋•Œ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•œ๋‹ค. 

     

    torch.no_grad() ๋˜ํ•œ ๊ทธ๋ž˜๋””์–ธํŠธ ๊ณ„์‚ฐ์„ ๋น„ํ™œ์„ฑํ™”ํ•ด์ฃผ๋Š” ์—ญํ• ์„ ํ•˜๋ฉฐ ์ด๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํŠน์ • ์˜์—ญ์„ ๊ฐ์‹ธ๋ฉด (ex. with) ๊ทธ ์•ˆ์—์„œ ์ˆ˜ํ–‰๋˜๋Š” ์—ฐ์‚ฐ์€ ๊ทธ๋ž˜๋””์–ธํŠธ๊ฐ€ ๊ณ„์‚ฐ๋˜์ง€ ์•Š๋Š”๋‹ค. 

     

    'AI\ML\DL > Pytorch' ์นดํ…Œ๊ณ ๋ฆฌ์˜ ๋‹ค๋ฅธ ๊ธ€

    torch.utils.data.DataLoader  (0) 2023.07.25
    nn.Linear  (0) 2023.07.14
    4๏ธโƒฃ Pytorch ๊ธฐ์ดˆ_์—ฌ๋Ÿฌ ํ•จ์ˆ˜๋“ค  (0) 2023.07.14
    3๏ธโƒฃ Pytorch ๊ธฐ์ดˆ_boolean ์ธ๋ฑ์‹ฑ  (0) 2023.07.06
Designed by Tistory.