-
Stochastic Gradient DescentAI\ML\DL/Deep learning theory 2023. 5. 6. 22:17๋ฐ์ํ
Vanilla GD vs. SGD
Gradient descent๋ ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ๋ค ๊ณ ๋ คํ๊ธฐ ๋๋ฌธ์ ์ต์๋ฅผ ํฅํ๋ ์๋ฒฝํ ๋ฐฉํฅ์ผ๋ก ๋์๊ฐ๋ค. ํ์ง๋ง, ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ๋ค ๊ณ ๋ คํ๊ธฐ ๋๋ฌธ์ ์๊ฐ์ด ๋๋ฌด ๋๋ฆฌ๋ค.
Stochastic gradient descent (SGD)๋ GD์ ๋ฌ๋ฆฌ ์ ์ฒด ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ ์ ๋ ฅํ์ง ์๊ณ ๋๋คํ๊ฒ ์ถ์ถํ ๋ฐ์ดํฐ๋ฅผ ํ๋์ฉ ์ ๋ ฅํด์ loss function๋ฅผ ๋ง๋๋ ๋ฐฉ๋ฒ์ด๋ค.
SGD๋ฅผ ์ฌ์ฉํ๋ฉด GD์์ ๋ฐ์ํ๋ ๊ณ์ฐ ์๋ ๋ฌธ์ ์ local minimum์ ๋น ์ง๋ ํ๊ณ๋ฅผ ๊ทน๋ณตํ ์ ์๋ค.
SGD์ ํน์ง
1. ๋๋คํ๊ฒ ๋ฐ์ดํฐ๋ฅผ ๋น๋ณต์์ถ์ถ๋ก ํ๋์ฉ ๋ฝ์์ loss๋ฅผ ๋ง๋ค๊ณ gradient๋ฅผ ๊ณ์ฐํ๋ค. (๋๋ค์ด๋ผ stochastic์ด๋ผ๋ ์ด๋ฆ์ด ๋ถ์๋ค.)
์ฆ, ๋ฐ์ดํฐ ํ๋๋ง ๋ณด๊ณ ๊ทธ๋๋์ธํธ๋ฅผ ๊ฒฐ์ ํ๊ธฐ ๋๋ฌธ์ ๋น ๋ฅด๊ฒ ์์น๋ฅผ ์ ๋ฐ์ดํธํ๋ค. ๋ฐ๋ผ์ gradient ๋ฐฉํฅ์ด ํญ์ ์ผ์ ํ์ง๋ ์๋ค.)
2. GD๋ ๋ฌด์กฐ๊ฑด ์ ์ฒด ๋ฐ์ดํฐ๋ฅผ ๋ค ๊ณ ๋ คํ ์์คํจ์๋ฅผ ๋ณด๊ณ ๊ทธ๋๋์ธํธ๋ฅผ ๊ณ์ฐํ๊ธฐ ๋๋ฌธ์ ํญ์ ๊ฐ์ ๋ฐฉํฅ์ผ๋ก ํฅํด์ local minimum์ ๋น ์ง ์ํ์ด ์๋ค. ํ์ง๋ง SGD๋ ํ๋์ ๋ฐ์ดํฐ๋ง ๋ฝ์์ ์์คํจ์๋ฅผ ๊ณ์ฐํ๊ธฐ ๋๋ฌธ์ iteration๋ง๋ค ์์คํจ์๊ฐ ๋ค๋ฅผ ์ ์๊ณ gradient ๋ํ ๋ฌ๋ผ์ง๋ค. ์ด๋ ๊ฒ ๋๋ฌธ์ local minimum ์์ ํ์ถํ ๊ธฐํ๊ฐ ์๋ค.
ํ์ง๋ง SGD๋ ๋ช๋ฐฑ๋ง๊ฐ์ ๋ฐ์ดํฐ ์ค์์ ํ๋์ฉ๋ง ๋ณด๊ณ ๊ฑ๋ง ๋ง์ถ๋ ค๊ณ ํ๊ธฐ ๋๋ฌธ์ ๋๋ฌด ์ฑ๊ธํ๊ฒ ๋ฐฉํฅ(gradient)๋ฅผ ๊ฒฐ์ ํด์ ์ ์คํ์ง ์๊ณ , ์ฆ์ ์์ง์์ด ์๋ค..
์ด๊ฒ ๋๋ฌธ์ ๋ฑ์ฅํ ๊ฒ์ด mini-batch SGD์ด๋ค.
mini-batch SGD๋ 2๊ฐ ์ด์์ฉ ๋ฐ์ดํฐ๋ฅผ ์ถ์ถํด์ loss๋ฅผ ๋ง๋๋ ๋ฐฉ๋ฒ์ด๋ค. ์๋ฅผ๋ค์ด mini-batch size = 2 ๋ผ๋ฉด, ๋ฐ์ดํฐ๋ฅผ 2๊ฐ์ฉ ๋น๋ณต์์ถ์ถ๋ก ๋๋คํ๊ฒ ๋ฝ๋๋ค๋ ์๋ฏธ์ด๋ค. (iteration์ด ๋ ๋จ์๋๋ฐ ๋จ์ ๋ฐ์ดํฐ๊ฐ ํ๋๋ผ๋ฉด mini batch ์ฌ์ด์ฆ๊ฐ 2๋ผ๋ ๊ทธ๋ฅ ํ๋๋ง ๋ฝ๋๋ค. ์์ธ๋๋ ๋ฐ์ดํฐ๊ฐ ์๋๋ก ํ๊ธฐ ์ํจ์ด๋ค.)
Mini-batch SGD์์ ๋ฌด๋ฆฌํ๊ฒ batch size๋ฅผ ํค์ฐ๋ฉด ์ฑ๋ฅ์ด ์์ข์์ง๋ค๋ ์ฐ๊ตฌ ๊ฒฐ๊ณผ๊ฐ ์๋ค.
batch size๊ฐ ์ปค์ง๋ฉด ๊ทธ๋งํผ GD์ ๋น์ทํด์ง๋ ๊ฑฐ๋๊น ์์ข์ local minimum์ผ๋ก ๋น ์ง ๊ฐ๋ฅ์ฑ์ด ์ปค์ง๊ธฐ ๋๋ฌธ์ด๋ค.
https://arxiv.org/abs/1706.02677
Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour
Deep learning thrives with large neural networks and large datasets. However, larger networks and larger datasets result in longer training times that impede research and development progress. Distributed synchronous SGD offers a potential solution to this
arxiv.org
๋ฐ๋ผ์ ์ด ๋ ผ๋ฌธ์์๋ batch size๋ฅผ ํค์ฐ๋ ค๋ฉด learning rate๋ ๊ฐ์ด ํค์ฐ๊ณ , warmup๋ ํด์ผ ๊ทธ๋๋ง ์์ batch size์ผ๋์ ์ฑ๋ฅ์ ์ป์ ์ ์๋ค๊ณ ํ๋ค.
'AI\ML\DL > Deep learning theory' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
Logistic Regression (0) 2023.05.08 Backpropagation (0) 2023.05.07 Momentum, RMSProp Optimizer (0) 2023.05.06 Gradient descent (0) 2023.05.06