-
CNN ๋ชจ๋ธ์ classifier๋จ์์ FC layer์ ์ ๋ ฅ ๋ ธ๋ ๊ฐ์AI\ML\DL/Pytorch 2023. 8. 3. 21:35๋ฐ์ํ
๏นก
์ด๋ฏธ์ง ๋ค์ค๋ถ๋ฅ ๋ฌธ์ ์์ CNN์ ์ฌ์ฉํ๋ฉด convolution layer ์ ReLU, Batchnorm2d, MaxPool2d๋ฅผ ๋ฐ๋ณตํ๋ค๊ฐ
๋ง์ง๋ง์๋ fully connected layer๋ฅผ ํต๊ณผํด์ softmax๋ฅผ ์ป์ด ๋ถ๋ฅ๋ฅผ ํด์ฃผ์ด์ผ ํ๋ค.
์๋ฅผ ๋ค์ด CIFAR10 ๋ฐ์ดํฐ์ ์ผ๋ก 10๊ฐ์ง๋ฅผ ๋ถ๋ฅํ๋ ๋คํธ์ํฌ๋ฅผ ๋ง๋ ๋ค๊ณ ํ์.
ํต์์ ์ธ ๋คํธ์์์๋ CNN ๋ ์ด์ด๋ฅผ ํต๊ณผํด์ ๋ค์ํ ํน์ง๋งต์ ๋ฝ์๋ด๊ณ ํ๋ง์ ํด์ ์ฌ์ด์ฆ๋ฅผ ๋ฐ์ผ๋ก ์ค์ธ๋ค.
(Maxpooling์ ํ๋ค๊ณ ํด์ ์ฑ๋์๊ฐ ์ค์ง๋ ์๋๋ค. ์ฑ๋๋ณ๋ก ๊ฐ๊ฐ ํ๋งํด์ฃผ๋ ๊ฑฐ๋๊น)
์ถฉ๋ถํ ํน์ง๋งต์ ์ป์๋ค๋ฉด nn.Linear( )๋ฅผ ํ์ฉํด์
์ต์ข ์ ์ธ ์ถ๋ ฅ ๋ ธ๋๋ฅผ 10๊ฐ๋ก ์ค์ฌ์ค์ผ ํ ๊ฒ์ด๋ค.
์ด๋ nn.Linear(?, 10)์ ? ๋ถ๋ถ์ ์ด๋ค๊ฒ์ ์จ์ผํ ์ง ์๋ฌธ์ด ๋ค์๋ค.
conv layer๋ฅผ ํต๊ณผํ ์งํ์ ํน์ง๋งต์ ์ฑ๋์์ ํ, ์ด ํฌ๊ธฐ๋ฅผ ํ๋์ ๋ด์๋ ์ ์ ์๊ธฐ ๋๋ฌธ์ด๋ค.
์ด๊ฒ์ ํ์ธํ๊ธฐ ์ํ ๋ฐฉ๋ฒ์ nn.Linear() ๋ถ๋ถ์ ์ฃผ์์ฒ๋ฆฌ ํด๋๊ณ ๊ทธ ์ ๊น์ง ํต๊ณผํ ๊ฒฐ๊ณผ์ shape์ ํ์ธํด๋ณด๋ ๊ฒ์ด๋ค.
import torch from torch import nn class CNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1,8,6,stride=2) self.conv2 = nn.Conv2d(8,16,3,padding=1) self.Maxpool2 = nn.MaxPool2d(2) # self.fc = nn.Linear(?,10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.Maxpool2(x) return x x = torch.randn(32,1,28,28) model = CNN() print(model(x).shape)
torch.Size([32, 16, 6, 6])
์์์ ํ์ธํ ์ ์๋ฏ์ด ๊ทธ ๊ฒฐ๊ณผ๋ torch.size([32, 16, 6, 6]) ์ด์๋ค.
์ฆ, fc layer๋ฅผ ํต๊ณผํ๊ธฐ ์ง์ ๊น์ง 16๊ฐ์ ์ฑ๋์ ๊ฐ์ง 6x6 ํฌ๊ธฐ์ feature map 32๊ฐ๊ฐ ์ถ๋ ฅ๋์๋ค๋ ์๋ฏธ์ด๋ค.
๋ฐ์ดํฐ ๋ฐฐ์น ์ฌ์ด์ฆ์ธ 32๋ฅผ ์ ์ธํ๊ณ 16*6*6 ์ flatten ์์ผ์ฃผ์ด ๋ฐ๋ก nn.Linear()์ ์ ๋ ฅ ๋ ธ๋ ๊ฐ์๋ก์ ์ ๋ ฅํด์ค์ผ ํ๋ค.
๋ฐ๋ผ์ ? ์ ๋ค์ด๊ฐ ์ซ์๋ 16*6*6 ์ด ๋จ์ ์ ์ ์๋ค.
์ฝ๋๋ฅผ ์์ ์์ผ์ฃผ์ด ํด๋์ค๋ฅผ ์์ฑ์์ผ์ฃผ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
fc ๋ ์ด์ด์ nn.Linear() ๋ฅผ ์ถ๊ฐํด์ฃผ๊ณ x๋ฅผ flatten ์์ผ์ฃผ์๋ค.
import torch from torch import nn class CNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1,8,6,stride=2) self.conv2 = nn.Conv2d(8,16,3,padding=1) self.Maxpool2 = nn.MaxPool2d(2) self.fc = nn.Linear(16*6*6,10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.Maxpool2(x) x = torch.flatten(x, start_dim=1) x = self.fc(x) return x
* ์ฐธ๊ณ ํ๊ธฐ
https://pytorch.org/docs/stable/generated/torch.flatten.html 'AI\ML\DL > Pytorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
torch.cumprod() w.r.t. diffusion noise scheduling (1) 2023.12.03 torch.cat vs torch.stack (0) 2023.09.29 PyTorch์ Numpy์์์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ ํฌ๋งท ์ฐจ์ด (0) 2023.08.01 torch.utils.data.DataLoader (0) 2023.07.25