-
nn.LinearAI\ML\DL/Pytorch 2023. 7. 14. 21:04๋ฐ์ํ
Fully connected layer
nn.Linear(์ ๋ ฅ๋ ธ๋ ๊ฐ์, ์ถ๋ ฅ๋ ธ๋ ๊ฐ์)
์์ ๋ฌธ์๋ฅผ ๋ณด๋ฉด,
nn.Linear() ์์ ์ ๋ ฅ ๋ฒกํฐ $x$ ์ ๊ฐ์ค์น ๋ฒกํฐ $w$ ๋ฅผ ๊ณฑํด์ค ๋, ๊ฐ์ค์น $w$์ transpose๋ฅผ ์ทจํด์ค์ ๊ณฑํด์ค๋ค.
๊ทธ ์ด์ ๋ ๋ฌด์์ผ๊น?
mlp ์์ ๋ ธ๋๋ ์ฑ๋์ ์๋ฏธํ๋ค.
nn.Linear ๋ ์ ๋ ฅ์ผ๋ก 1D data๊ฐ ๋ค์ด์ค๊ธธ ๊ธฐ๋ํ๋๋ฐ,
๋ฐ์ดํฐ๋ฅผ ์ฌ๋ฌ ๊ฐ ํต๊ณผ์ํฌ ๋ ํน์ ์ฑ๋ ์๋ฅผ ๊ฐ์ง ๋ฐ์ดํฐ n๊ฐ๋ฅผ (๊ฐ์n $\times$ ์ฑ๋) ํํ๋ก ํต๊ณผ์์ผ์ผ ํ๋ค.
weight ๋ํ (๊ฐ์x์ฑ๋) ํํ๋ก ํํํด์ฃผ์.
์ผ๋จ weight์ ์ฑ๋์ ์์ ์๋ ๋ ธ๋์ ์ฑ๋์ด๋ ๋ง์ถฐ์ผ ํด์ ๊ฐ์ค์น์ ์ฑ๋์ ์ ๋ ฅ ๋ ธ๋์ ์ฑ๋์๋ก ์ ํด์ง ์ํ์ด๋ค.
์๋ฅผ ๋ค์ด nn.Linear(3,2) ์ผ๋ก ์ ๋ ฅ ๋ ธ๋ 3๊ฐ์์ ์ถ๋ ฅ ๋ ธ๋ 2๊ฐ๋ฅผ ๋ง๋ ๋ค๊ณ ํ ๋,
๊ฐ์ค์น์ ์ฑ๋ ํฌ๊ธฐ๋ 3์ด๋ค. (=์ ๋ ฅ๋ ธ๋์ ์ฑ๋์, ๊ฐ์ค์น์ ์ฑ๋์๋ ์ ๋ ฅ๋ ธ๋์ ์ฑ๋์์ ์ผ์น)
๊ทธ๋ฆฌ๊ณ 3์ฑ๋์ ๊ฐ์ง๊ณ 2๊ฐ์ ๋ ธ๋๋ฅผ ๋ง๋๋ ๊ฑฐ๋ผ์ ๊ฐ์๋ 2๊ฐ๊ฐ ๋๋ค.
์ฆ, weight์ shape์ 2x3(๊ฐx์ฒด) ๊ฐ ๋๋ค.
์ ๋ ฅ ๋ ธ๋์ shape ์ด 1x3 ์ด๋ผ๊ณ ํ๋ค๋ฉด ํ๋ ฌ์ ๊ณฑ์ ์ ์ํ ํฌ๊ธฐ๋ฅผ ๋ง์ถฐ์ฃผ๊ธฐ ์ํด ๊ฐ์ค์น์ ๊ผญ transpose๋ฅผ ํด์ค์
shape์ 3x2๋ก ๋ง์ถฐ์ค์ผ ํ๋๊ฒ์ด๋ค.
- nn.Linear(3,2): ์ ๋ ฅ ๋ ธ๋ 3๊ฐ, ์ถ๋ ฅ ๋ ธ๋ 2๊ฐ์ธ ๋ ์ด์ด๋ฅผ ๊ฐ์ง ๋ชจ๋ธ
- ์ ๋ ฅ ํ ์: 3๊ฐ์ ์ฑ๋์ ๊ฐ์ง ๋ฐ์ดํฐ 5๊ฐ
- ์ถ๋ ฅ ํ ์: 2๊ฐ์ ์ฑ๋์ ๊ฐ์ง ๋ฐ์ดํฐ 5๊ฐ
- $\textbf{w}$ ํ ์: 3๊ฐ์ ์ฑ๋์ ๊ฐ์ง ๋ฐ์ดํฐ (๊ฐ์ค์น ์ธํธ) 2๊ฐ
์ด๋ฅผ ๊ทธ๋ฆผ์ผ๋ก ๋ํ๋ด๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
nn.Linear(3,2) ์์ ๋ฐ์ดํฐ ๊ฐ์๋ ์ผ๋ง๋ ์ง ๋ฐ๋ ์ ์์ง๋ง (์ํ๋ ๋ฐ์ดํฐ ๊ฐ์ ๋งํผ)
์ฑ๋์๋ ๋ง์ถฐ์ค์ผ ํ๋ค.
'AI\ML\DL > Pytorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
PyTorch์ Numpy์์์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ ํฌ๋งท ์ฐจ์ด (0) 2023.08.01 torch.utils.data.DataLoader (0) 2023.07.25 5๏ธโฃ Pytorch ๊ธฐ์ด_autograd (0) 2023.07.14 4๏ธโฃ Pytorch ๊ธฐ์ด_์ฌ๋ฌ ํจ์๋ค (0) 2023.07.14