  • torch.utils.data.DataLoader
    AI\ML\DL/Pytorch 2023. 7. 25. 18:41

    DataLoader

    DataLoader splits a dataset into mini-batches and feeds them to the training process.

    A DataLoader object keeps hold of the full dataset used for training and, while the model trains, hands out one batch-size chunk of data at a time.

    Note that the data is not sliced into batches ahead of time. Internally, the iterator keeps track of sample indices and uses them to fetch and return one batch of data at a time.
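    The index-based batching described above can be sketched by hand. This is only an illustration, not DataLoader's actual implementation: it uses the real BatchSampler and SequentialSampler classes from torch.utils.data to produce batches of indices, while the fetch helper and the example tensors are hypothetical names made up for this sketch.

```python
import torch
from torch.utils.data import BatchSampler, SequentialSampler, TensorDataset

# Hypothetical example data: 5 samples with 2 features each
data = torch.arange(10).reshape(5, 2)
labels = torch.tensor([0, 1, 0, 1, 0])
dataset = TensorDataset(data, labels)

# The sampler yields lists of indices, not slices of the data itself
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
index_batches = list(batch_sampler)
print(index_batches)


def fetch(indices):
    """Hypothetical helper: look up each index, then stack the samples into one batch."""
    samples = [dataset[i] for i in indices]
    inputs = torch.stack([x for x, _ in samples])
    targets = torch.stack([y for _, y in samples])
    return inputs, targets


# Batches are only materialized here, when the indices are actually used
manual_batches = [fetch(indices) for indices in index_batches]
for inputs, targets in manual_batches:
    print(inputs.shape, targets.shape)
```

    The dataset is never copied or cut up in advance; only the index lists change from batch to batch.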

     

    • DataLoader usage example

    The code below builds a dataset from a data tensor (example data) and a labels tensor, each with 5 samples, then creates a dataloader that draws the samples without replacement, one batch at a time, and prints each batch.

    To do this, the dataset is passed to torch.utils.data.DataLoader as an argument. Here it wraps tensors, but a dataset can equally be backed by a CSV file.
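    For the CSV case mentioned above, one possible approach is a small custom Dataset. The sketch below is an assumption-laden illustration: the CsvDataset class, the column names x1, x2 and label, and the inline CSV text are all hypothetical, and an in-memory string stands in for a real file.

```python
import csv
import io

import torch
from torch.utils.data import DataLoader, Dataset

# Hypothetical CSV content; a real file opened with open(...) would work the same way
CSV_TEXT = "x1,x2,label\n1,2,0\n3,4,1\n5,6,0\n"


class CsvDataset(Dataset):
    """Hypothetical dataset that reads every row of a CSV into tensors up front."""

    def __init__(self, file_obj):
        rows = list(csv.DictReader(file_obj))
        self.features = torch.tensor([[float(r["x1"]), float(r["x2"])] for r in rows])
        self.labels = torch.tensor([int(r["label"]) for r in rows])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # One (features, label) pair per index, which DataLoader collates into batches
        return self.features[idx], self.labels[idx]


csv_dataset = CsvDataset(io.StringIO(CSV_TEXT))
csv_loader = DataLoader(csv_dataset, batch_size=2, shuffle=False)
csv_batches = list(csv_loader)
for features, targets in csv_batches:
    print(features, targets)
```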

     

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    
    # Example data and labels
    data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
    labels = torch.tensor([0, 1, 0, 1, 0])
    
    # Creating a TensorDataset
    dataset = TensorDataset(data, labels)
    
    # Creating a DataLoader
    batch_size = 2
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Iterating over the DataLoader
    for minibatch in dataloader:
        inputs, targets = minibatch
        # Perform operations with the current mini-batch, e.g., model training
        print("Input batch:", inputs)
        print("Target batch:", targets)

    DataLoader is an iterable object, so a for loop can be used to print the data and labels drawn in batches of 2.

    One run produced the output below. Because shuffle=True, the order differs between runs.

    Input batch: tensor([[1, 2],
            [7, 8]])
    Target batch: tensor([0, 1])
    Input batch: tensor([[ 3,  4],
            [ 9, 10]])
    Target batch: tensor([1, 0])
    Input batch: tensor([[5, 6]])
    Target batch: tensor([0])

     

    Note that when the total number of samples is not evenly divisible by the batch size, the last batch holds only the remainder and is therefore smaller.

    In this code, the 5 samples are not divisible by the batch size of 2, so the last batch contains a single sample.

    Also, because the DataLoader was created with shuffle=True, the samples were drawn in random order.

     

    1. Batch 1: [[1, 2], [7, 8]]
    2. Batch 2: [[3, 4], [9, 10]]
    3. Batch 3: [[5, 6]] (note: when the total number of samples is not divisible by the batch size, the last batch may hold fewer samples)
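    The batch sizes and the shuffling can be checked with a short sketch. It reuses the example tensors from the post and relies on two DataLoader arguments the post does not use, drop_last and generator; the batch_order helper is a hypothetical name introduced only for this illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
labels = torch.tensor([0, 1, 0, 1, 0])
dataset = TensorDataset(data, labels)

# Default behaviour: the incomplete last batch is kept
keep_loader = DataLoader(dataset, batch_size=2, shuffle=True)
# drop_last=True discards the incomplete last batch instead
drop_loader = DataLoader(dataset, batch_size=2, shuffle=True, drop_last=True)

keep_sizes = [len(targets) for _, targets in keep_loader]
drop_sizes = [len(targets) for _, targets in drop_loader]
print(len(keep_loader), keep_sizes)
print(len(drop_loader), drop_sizes)


def batch_order(seed):
    """Hypothetical helper: list the input batches produced with a seeded generator."""
    generator = torch.Generator().manual_seed(seed)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, generator=generator)
    return [inputs.tolist() for inputs, _ in loader]


# Seeding the generator the same way should reproduce the same shuffled order
same_order = batch_order(0) == batch_order(0)
print(same_order)
```

    Passing a seeded generator is one way to make a shuffled run repeatable while debugging.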

    next(iter(DataLoader))

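    • Without a for loop, you can step through the batches by calling next() repeatedly on one iterator obtained from iter(dataloader). Calling next(iter(dataloader)) again builds a fresh iterator each time, so it returns the first batch of a new (reshuffled) pass rather than continuing the previous one.
    • This is handy for checking that the data comes out as expected.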
    # next(iter(dataloader)) returns the first batch of a fresh pass over the data
    inputs, targets = next(iter(dataloader))
    print(inputs)
    print(targets)
    tensor([[5, 6],
            [7, 8]])
    tensor([0, 1])
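    To walk through successive batches without a for loop, one option is to create the iterator once and keep calling next() on it. The sketch below assumes the same example tensors as the post and sets shuffle=False only so the batches come out in a predictable order; the variable names are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
labels = torch.tensor([0, 1, 0, 1, 0])
dataloader = DataLoader(TensorDataset(data, labels), batch_size=2, shuffle=False)

# Create the iterator once, then advance it manually
batch_iter = iter(dataloader)
first_inputs, first_targets = next(batch_iter)
second_inputs, second_targets = next(batch_iter)
third_inputs, third_targets = next(batch_iter)

# After the last batch the iterator is exhausted; the default value avoids StopIteration
exhausted = next(batch_iter, None) is None

print(first_inputs)
print(second_inputs)
print(third_inputs)
print(exhausted)
```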