ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • CNN ๋ชจ๋ธ์˜ classifier๋‹จ์—์„œ FC layer์˜ ์ž…๋ ฅ ๋…ธ๋“œ ๊ฐœ์ˆ˜
    AI\ML\DL/Pytorch 2023. 8. 3. 21:35
    ๋ฐ˜์‘ํ˜•

    ๏นก

     

    ์ด๋ฏธ์ง€ ๋‹ค์ค‘๋ถ„๋ฅ˜ ๋ฌธ์ œ์—์„œ CNN์„ ์‚ฌ์šฉํ•˜๋ฉด convolution layer ์™€ ReLU, Batchnorm2d, MaxPool2d๋ฅผ ๋ฐ˜๋ณตํ•˜๋‹ค๊ฐ€

    ๋งˆ์ง€๋ง‰์—๋Š” fully connected layer๋ฅผ ํ†ต๊ณผํ•ด์„œ softmax๋ฅผ ์–ป์–ด ๋ถ„๋ฅ˜๋ฅผ ํ•ด์ฃผ์–ด์•ผ ํ•œ๋‹ค.

     

    ์˜ˆ๋ฅผ ๋“ค์–ด CIFAR10 ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ 10๊ฐ€์ง€๋ฅผ ๋ถ„๋ฅ˜ํ•˜๋Š” ๋„คํŠธ์›Œํฌ๋ฅผ ๋งŒ๋“ ๋‹ค๊ณ  ํ•˜์ž.

    ํ†ต์ƒ์ ์ธ ๋„คํŠธ์›์—์„œ๋Š” CNN ๋ ˆ์ด์–ด๋ฅผ ํ†ต๊ณผํ•ด์„œ ๋‹ค์–‘ํ•œ ํŠน์ง•๋งต์„ ๋ฝ‘์•„๋‚ด๊ณ  ํ’€๋ง์„ ํ•ด์„œ ์‚ฌ์ด์ฆˆ๋ฅผ ๋ฐ˜์œผ๋กœ ์ค„์ธ๋‹ค.

    (Maxpooling์„ ํ•œ๋‹ค๊ณ  ํ•ด์„œ ์ฑ„๋„์ˆ˜๊ฐ€ ์ค„์ง€๋Š” ์•Š๋Š”๋‹ค. ์ฑ„๋„๋ณ„๋กœ ๊ฐ๊ฐ ํ’€๋งํ•ด์ฃผ๋Š” ๊ฑฐ๋‹ˆ๊นŒ)

     

    ์ถฉ๋ถ„ํžˆ ํŠน์ง•๋งต์„ ์–ป์—ˆ๋‹ค๋ฉด nn.Linear( )๋ฅผ ํ™œ์šฉํ•ด์„œ

    ์ตœ์ข…์ ์ธ ์ถœ๋ ฅ ๋…ธ๋“œ๋ฅผ 10๊ฐœ๋กœ ์ค„์—ฌ์ค˜์•ผ ํ•  ๊ฒƒ์ด๋‹ค. 

    ์ด๋•Œ nn.Linear(?, 10)์˜ ? ๋ถ€๋ถ„์— ์–ด๋–ค๊ฒƒ์„ ์จ์•ผํ•  ์ง€ ์˜๋ฌธ์ด ๋“ค์—ˆ๋‹ค. 

    conv layer๋ฅผ ํ†ต๊ณผํ•œ ์งํ›„์— ํŠน์ง•๋งต์˜ ์ฑ„๋„์ˆ˜์™€ ํ–‰, ์—ด ํฌ๊ธฐ๋ฅผ ํ•œ๋ˆˆ์— ๋ด์„œ๋Š” ์•Œ ์ˆ˜ ์—†๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. 

    ์ด๊ฒƒ์„ ํ™•์ธํ•˜๊ธฐ ์œ„ํ•œ ๋ฐฉ๋ฒ•์€ nn.Linear() ๋ถ€๋ถ„์„ ์ฃผ์„์ฒ˜๋ฆฌ ํ•ด๋†“๊ณ  ๊ทธ ์ „๊นŒ์ง€ ํ†ต๊ณผํ•œ ๊ฒฐ๊ณผ์˜ shape์„ ํ™•์ธํ•ด๋ณด๋Š” ๊ฒƒ์ด๋‹ค. 

     

    import torch
    from torch import nn
    class CNN(nn.Module):
        def __init__(self):
            super().__init__()
    
            self.conv1 = nn.Conv2d(1,8,6,stride=2)
            self.conv2 = nn.Conv2d(8,16,3,padding=1)
            self.Maxpool2 = nn.MaxPool2d(2)
            # self.fc = nn.Linear(?,10)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.Maxpool2(x)
            return x
    
    x = torch.randn(32,1,28,28)
    model = CNN()
    print(model(x).shape)
    torch.Size([32, 16, 6, 6])

    ์œ„์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋“ฏ์ด ๊ทธ ๊ฒฐ๊ณผ๋Š” torch.size([32, 16, 6, 6]) ์ด์—ˆ๋‹ค.

    ์ฆ‰, fc layer๋ฅผ ํ†ต๊ณผํ•˜๊ธฐ ์ง์ „๊นŒ์ง€ 16๊ฐœ์˜ ์ฑ„๋„์„ ๊ฐ€์ง„ 6x6 ํฌ๊ธฐ์˜ feature map 32๊ฐœ๊ฐ€ ์ถœ๋ ฅ๋˜์—ˆ๋‹ค๋Š” ์˜๋ฏธ์ด๋‹ค. 

    ๋ฐ์ดํ„ฐ ๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ์ธ 32๋ฅผ ์ œ์™ธํ•˜๊ณ  16*6*6 ์€ flatten ์‹œ์ผœ์ฃผ์–ด ๋ฐ”๋กœ nn.Linear()์˜ ์ž…๋ ฅ ๋…ธ๋“œ ๊ฐœ์ˆ˜๋กœ์„œ ์ž…๋ ฅํ•ด์ค˜์•ผ ํ•œ๋‹ค. 

    ๋”ฐ๋ผ์„œ ? ์— ๋“ค์–ด๊ฐˆ ์ˆซ์ž๋Š” 16*6*6 ์ด ๋จ์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค. 

     

    ์ฝ”๋“œ๋ฅผ ์ˆ˜์ •์‹œ์ผœ์ฃผ์–ด ํด๋ž˜์Šค๋ฅผ ์™„์„ฑ์‹œ์ผœ์ฃผ๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค. 

    fc ๋ ˆ์ด์–ด์— nn.Linear() ๋ฅผ ์ถ”๊ฐ€ํ•ด์ฃผ๊ณ  x๋ฅผ flatten ์‹œ์ผœ์ฃผ์—ˆ๋‹ค.

    import torch
    from torch import nn
    class CNN(nn.Module):
        def __init__(self):
            super().__init__()
    
            self.conv1 = nn.Conv2d(1,8,6,stride=2)
            self.conv2 = nn.Conv2d(8,16,3,padding=1)
            self.Maxpool2 = nn.MaxPool2d(2)
            self.fc = nn.Linear(16*6*6,10)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.Maxpool2(x)
            x = torch.flatten(x, start_dim=1)
            x = self.fc(x)
            return x

     

    * ์ฐธ๊ณ ํ•˜๊ธฐ

     

    https://pytorch.org/docs/stable/generated/torch.flatten.html

Designed by Tistory.