-
5๏ธโฃ Pytorch ๊ธฐ์ด_autogradAI\ML\DL/Pytorch 2023. 7. 14. 17:08๋ฐ์ํ
1) requires_grad
ํ ์ ๊ฐ์ฒด์ attribute ๋ก, ๊ทธ๋๋์ธํธ์ ๊ณ์ฐ ์ฌ๋ถ๋ฅผ ๊ฒฐ์ ํ๋ค.
requires_grad ์์ฑ์ True ๋ก ์ค์ ํ์ฌ ํ ์์ ๋ํ ๊ทธ๋๋์ธํธ ๊ณ์ฐ์ ํ์ฑํํ ์ ์๋ค.
requres_grad ์์ฑ์ด False ๋ก ์ค์ ๋ ํ ์์ ๋ํด์๋ ๊ทธ๋๋์ธํธ๋ฅผ ๊ณ์ฐํ์ง ์๋๋ค. ์ด๋ฅผ ํตํด ๋ถํ์ํ ๊ณ์ฐ์ ์ค์ฌ ์ฐ์ฐ ์๋๋ฅผ ํฅ์์ํฌ ์ ์๋ค.
2) ์ญ์ ํ ๊ณ์ฐ
x=torch.tensor([1.], requires_grad=True) y=x**2 y.retain_grad() z=3*y z.backward() print(x.grad) print(y.grad)
x -> x**2 -> 3x**2 ์์ผ๋ก ์งํ๋๋ฏ๋ก
x ์ ๋ํด์ 3x**2 ๋ฅผ chain rule์ ํตํด ํธ๋ฏธ๋ถํ๋ฉด ๊ฒฐ๊ณผ๋ 6์ด ๋๋ค.
์ถ๋ ฅ์ ๋ค์๊ณผ ๊ฐ๋ค.
tensor([6.]) tensor([3.])
์ค๊ฐ์ ์๋ y ์ ๋ฏธ๋ถ ๊ณ์ฐ์ ์๋ ๋์ง ์๋๋ฐ, y.retain_grad ๋ก ๋ฏธ๋ถ๊ฐ์ ์ ์ง์์ผ ์ฃผ์์ผ๋ฏ๋ก ๋ณผ ์ ์๋ค.
3) detach ์ torch.no_grad
detach() ๋ฉ์๋๋ ํ ์๋ฅผ ๊ทธ๋ํ ๊ณ์ฐ์์ ๋ถ๋ฆฌํ๋ค. ์ฆ, ๊ทธ๋๋์ธํธ ๊ณ์ฐ์ ๋นํ์ฑํํ๊ณ ํด๋น ํ ์๋ฅผ ์๋ก์ด ํ ์๋ก ๋ง๋ ๋ค. ์ฒ์์ requires_grad ๋ฅผ True๋ก ์ค์ ํ๋๋ผ๋ detach๋ฅผ ํด์ฃผ๋ฉด ์ญ์ ํ๋ฅผ ์ํํ ๋ ์ค๋ฅ๊ฐ ๋ฐ์ํ๋ค.
torch.no_grad() ๋ํ ๊ทธ๋๋์ธํธ ๊ณ์ฐ์ ๋นํ์ฑํํด์ฃผ๋ ์ญํ ์ ํ๋ฉฐ ์ด๋ฅผ ์ฌ์ฉํ์ฌ ํน์ ์์ญ์ ๊ฐ์ธ๋ฉด (ex. with) ๊ทธ ์์์ ์ํ๋๋ ์ฐ์ฐ์ ๊ทธ๋๋์ธํธ๊ฐ ๊ณ์ฐ๋์ง ์๋๋ค.
'AI\ML\DL > Pytorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
torch.utils.data.DataLoader (0) 2023.07.25 nn.Linear (0) 2023.07.14 4๏ธโฃ Pytorch ๊ธฐ์ด_์ฌ๋ฌ ํจ์๋ค (0) 2023.07.14 3๏ธโฃ Pytorch ๊ธฐ์ด_boolean ์ธ๋ฑ์ฑ (0) 2023.07.06