-
torch.utils.data.DataLoaderAI\ML\DL/Pytorch 2023. 7. 25. 18:41๋ฐ์ํ
DataLoader
DataLoader ๋ ๋ฐ์ดํฐ์ ์ ๋ฏธ๋ ๋ฐฐ์น ๋จ์๋ก ๋ถํ ํ์ฌ ํ์ตํ๋ก์ธ์ค์ ์ ๊ณตํด์ฃผ๋ ์ญํ ์ ํ๋ค.
๋ฐ์ดํฐ๋ก๋ (DataLoader) ๊ฐ์ฒด๋ ํ์ต์ ์ฌ์ฉ๋ ๋ฐ์ดํฐ ์ ์ฒด๋ฅผ ๋ณด๊ดํ๋ค๊ฐ ๋ชจ๋ธ ํ์ต์ ํ ๋ ๋ฐฐ์น ํฌ๊ธฐ๋งํผ ๋ฐ์ดํฐ๋ฅผ ๊บผ๋ด์ ์ฌ์ฉํ๋ค.
์ด๋ ์ฃผ์ํ ๊ฒ์ ๋ฐ์ดํฐ๋ฅผ ๋ฏธ๋ฆฌ ์๋ผ ๋๋ ๊ฒ์ด ์๋๋ผ ๋ด๋ถ์ ์ผ๋ก ๋ฐ๋ณต์ (iterator)์ ํฌํจ๋ ์ธ๋ฑ์ค (index)๋ฅผ ์ด์ฉํ์ฌ ๋ฐฐ์น ํฌ๊ธฐ๋งํผ ๋ฐ์ดํฐ๋ฅผ ๋ฐํํ๋ค๋ ๊ฒ์ด๋ค.
- DataLoader ์ฌ์ฉ ์์
์๋์ ์ฝ๋๋ ๊ฐ๊ฐ 5๊ฐ์ ์ํ์ ๊ฐ์ง ๋ฐ์ดํฐ (data; example data) ์ ๋ ์ด๋ธ (labels) ํ ์๋ก dataset์ ๋ง๋ ํ์
dataloader๋ฅผ ๋ง๋ค์ด์ ๋ฐฐ์น ํฌ๊ธฐ๋งํผ ๋ฐ์ดํฐ๋ฅผ ๋น๋ณต์์ถ์ถํ์ฌ ํ๋ฆฐํธํ๋ ๊ฐ๋จํ ์ฝ๋์ด๋ค.
์ด๋ฅผ ์ํด ํ ์๋ csvํ์ผ์ด ๋ ์ ์๋ dataset์ torch.utils.data.DataLoader์ ํ๋ผ๋ฏธํฐ๋ก ์ ๋ฌํ์๋ค.
import torch from torch.utils.data import DataLoader, TensorDataset # Example data and labels data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) labels = torch.tensor([0, 1, 0, 1, 0]) # Creating a TensorDataset dataset = TensorDataset(data, labels) # Creating a DataLoader batch_size = 2 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # Iterating over the DataLoader for minibatch in dataloader: inputs, targets = minibatch # Perform operations with the current mini-batch, e.g., model training print("Input batch:", inputs) print("Target batch:", targets)
DataLoader๋ iterable (๋ฐ๋ณต ๊ฐ๋ฅ) ํ ๊ฐ์ฒด์ด๋ค. ๋ฐ๋ผ์ for๋ฌธ์ ์ฌ์ฉํ์ฌ 2๊ฐ ๋ฐฐ์น๋งํผ ์ถ์ถ๋ data์ label์ ํ๋ฆฐํธํ ์ ์๋ค.
์ถ๋ ฅ์ ๋ค์๊ณผ ๊ฐ๋ค.
Input batch: tensor([[1, 2], [7, 8]]) Target batch: tensor([0, 1]) Input batch: tensor([[ 3, 4], [ 9, 10]]) Target batch: tensor([1, 0]) Input batch: tensor([[5, 6]]) Target batch: tensor([0])
์ฐธ๊ณ ๋ก ์ด ์ํ ์๊ฐ ๋ฐฐ์น ํฌ๊ธฐ๋ก ๋๋์ด ๋จ์ด์ง์ง ์๋ ๊ฒฝ์ฐ ๋๋จธ์ง ๋งํผ์ ๋ ์ ์ ์์ ์ํ์ ๊ฐ์ง ์ ์๋ค.
์ด ์ฝ๋์์๋ ์ด ์ํ ๊ฐ์์ธ 5๊ฐ๊ฐ ๋ฐฐ์น์ฌ์ด์ฆ์ธ 2๋ก ๋๋์ด ๋จ์ด์ง์ง ์์์ last batch ์ ๊ฐ์๋ 1๊ฐ ์์ ์ ์ ์๋ค.
๋ํ DataLoader์์ shuffle=True ๋ก ์ค์ ํ๊ธฐ ๋๋ฌธ์ ์ํ์ ์์๋ฅผ ๋๋ค์ผ๋ก ์์ด์ ์ถ์ถํ์๋ค.
- ๋ฐฐ์น 1: [[1, 2], [7, 8]]
- ๋ฐฐ์น 2: [[3, 4], [9, 10]]
- ๋ฐฐ์น 3: [[5, 6]] (์ฐธ๊ณ : ์ด ์ํ ์๊ฐ ๋ฐฐ์น ํฌ๊ธฐ๋ก ๋๋ ์ ์๋ ๊ฒฝ์ฐ ๋ง์ง๋ง ๋ฐฐ์น๋ ๋ ์ ์ ์์ ์ํ์ ๊ฐ์ง ์ ์์)
next(iteration(DataLoader))
- for๋ฌธ์ ์ฌ์ฉํ์ง ์๊ณ ๋ ์ ์ฝ๋๋ฅผ ๋ฐ๋ณตํ๋ฉด ๊ณ์์ ์ผ๋ก ๋ค์ ๋ฐฐ์น๋ฅผ ๋ณผ ์ ์๋ค.
- ๋ฐ์ดํฐ ์ถ๋ ฅ์ด ์ ๋์ค๋ ์ง ํ์ ํ ๋ ์ธ ์ ์๋ค.
# next(iter(DataLoader)) input, target = next(iter(dataloader)) print(input) print(target)
tensor([[5, 6], [7, 8]]) tensor([0, 1])
'AI\ML\DL > Pytorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
CNN ๋ชจ๋ธ์ classifier๋จ์์ FC layer์ ์ ๋ ฅ ๋ ธ๋ ๊ฐ์ (0) 2023.08.03 PyTorch์ Numpy์์์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ ํฌ๋งท ์ฐจ์ด (0) 2023.08.01 nn.Linear (0) 2023.07.14 5๏ธโฃ Pytorch ๊ธฐ์ด_autograd (0) 2023.07.14