-
๋ถ์ฐ ๋ฐ ๋ณ๋ ฌ ํ์ตAI\ML\DL/Pytorch 2023. 12. 13. 15:27๋ฐ์ํ
๋ถ์ฐ ํ์ต(Distributed training)์ ํ์ต ์ํฌ๋ก๋๋ฅผ ์ฌ๋ฌ ์์ ์ ๋ ธ๋์ ๋ถ์ฐ์์ผ ํ๋ จ ์๋์ ๋ชจ๋ธ ์ ํ๋๋ฅผ ํฌ๊ฒ ํฅ์์ํค๋ ๋ชจ๋ธ ํ์ต ํจ๋ฌ๋ค์์ด๋ค. PyTorch์์ ๋ถ์ฐ ํ์ต์ ์ํํ ์ ์๋ ๋ช ๊ฐ์ง ๋ฐฉ๋ฒ์ด ์์ผ๋ฉฐ ๊ฐ ๋ฐฉ๋ฒ์ ์ฌ์ฉ ์ฉ๋๋ณ๋ก ์ฅ์ ์ ๊ฐ์ง๋ค.
- DistributedDataParallel (DDP)
- Fully Shared Data Parallel (FSDP)
- Remote Procedure Call (RPC) distributed traininng
- Custom Extensions
๋น๋ถ์ฐ ํ์ต์ ๋จ์ผ GPU์์ ๋ชจ๋ธ์ ํ์ต์ํจ๋ค. ํ์ต ๊ณผ์ ์ ๋ค์๊ณผ ๊ฐ๋ค.
1. DataLoader๋ก๋ถํฐ ์ ๋ ฅ ๋ฐฐ์น๋ฅผ ๋ฐ๋๋ค.
2. Forward pass๋ฅผ ํตํด loss ๊ณ์ฐ
3. Backward pass๋ฅผ ํตํด ๋งค๊ฐ๋ณ์ ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ
4. Optimizer๊ฐ ๊ธฐ์ธ๊ธฐ๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ ์ ๋ฐ์ดํธ
๋ฐ๋ฉด ๋ถ์ฐ ํ์ต(DDP)๋ ํ์ต ์์ ์ ์ฌ๋ฌ๊ฐ์ GPU๋ก ๋ถ์ฐํ๋ค. ๊ฐ GPU ํ๋ก์ธ์ค๋ ๋ค์๊ณผ ๊ฐ์ด ๋ชจ๋ธ์ ๋ก์ปฌ ๋ณต์ฌ๋ณธ์ ๊ฐ์ง๊ณ ์๋ค.
์ด๋ ๋ชจ๋ ๋ชจ๋ธ ๋ณต์ ๋ณธ๊ณผ ์ตํฐ๋ง์ด์ ๋ ์๋ก ๋์ผํ๋ฉฐ, ์ด๊ธฐ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฟ๋ง ์๋๋ผ ์ตํฐ๋ง์ด์ ๋ ๋์ผํ ๋๋ค ์๋๋ฅผ ์ฌ์ฉํ๋ค. DDP๋ ํ์ต ๊ณผ์ ์ ์ฒด์์ ์ด ๋๊ธฐํ๋ฅผ ์ ์งํ๋ค.
๋ฐ์ดํฐ ๋ณ๋ ฌ
๊ฐ GPU ํ๋ก์ธ์ค๋ ๋์ผํ ๋ชจ๋ธ์ ๊ฐ์ง๊ณ ์์ง๋ง DistributedSampler์ ์ํด ๋ฐ์ดํฐ๋ก๋๋ก๋ถํฐ ๋ฐ์ input batch ๋ฅผ ๋๋ ๋ฐ๊ธฐ ๋๋ฌธ์ ๊ฐ ๋ชจ๋ธ์ด ์ ๋ฌ๋ฐ๋ ๋ฐ์ดํฐ๋ ๋ค๋ฅด๋ค. (์ํ๋ฌ๋ ๊ฐ ํ๋ก์ธ์ค๊ฐ ๋ค๋ฅธ ์ ๋ ฅ์ ๋ฐ๋๋ก ํ๋ ์ญํ ์ ํ๋ค.) ์ด ๊ณผ์ ์ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ผ๊ณ ํ๋ค.
๊ฐ ํ๋ก์ธ์ค์์ ๋ชจ๋ธ(worker)์ Localํ๊ฒ forward ๋ฐ backward pass๋ฅผ ์คํํ๋ค. ์ด๋ ์๋ก ๋ค๋ฅธ ์ ๋ ฅ์ ์ฒ๋ฆฌํ๊ธฐ ๋๋ฌธ์ ๋์ ๋ gradient๊ฐ๋ ๋ค๋ฅด๋ค. ์ด ์์ ์์ optimizer ๋จ๊ณ๋ฅผ ์คํํ๋ฉด ํ๋ก์ธ์ค๊ฐ ์๋ก ๋ค๋ฅธ ํ๋ผ๋ฏธํฐ๊ฐ ์์ฑ๋๋ค.
DDP ๋ ์ด ์์ ์์ synchronization ์ ์์ํ๋ค. ๊ฐ ํ๋ก์ธ์ค(๋ชจ๋ธ)๋ก๋ถํฐ ํ์ต๋ gradient๋ bucket-allreduce ์๊ณ ๋ฆฌ์ฆ์ ํ์ฉํ์ฌ ์ง๊ณ๋๋ค. ์ด ๋๊ธฐํ ๊ณผ์ ์ ๋ค๋ฅธ ํ๋ก์ธ์ค์ gradient๊ฐ ๊ณ์ฐ๋ ๋๊น์ง ๊ธฐ๋ค๋ฆฌ์ง ์๊ณ ์์ง ์คํ์ค์ธ ๋ง์ ๋ฐ๋ผ ํต์ ์ ์์ํ๊ณ ๋๊ธฐํ๋ฅผ ํ๋ค. ์ด๋ฅผ ํตํด GPU๊ฐ ํญ์ ์๋ํ๊ฒ ๋๊ณ idle ์ํ(์ฌ์ฉ๊ฐ๋ฅํ ์ํ์ด์ง๋ง ๋ ธ๋ ์ํ) ๊ฐ ๋์ง ์๋๋ค. Optimizer ๋จ๊ณ์์๋ ๋๊ธฐํ๋์ด ์ง๊ณ๋ ๋ชจ๋ธ๋ค์ ํ๋ผ๋ฏธํฐ๊ฐ ํ ๋ฒ์ ์ ๋ฐ์ดํธ๋๋ค.
'AI\ML\DL > Pytorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
torch.log_softmax (0) 2024.03.11 torch.cumprod() w.r.t. diffusion noise scheduling (1) 2023.12.03 torch.cat vs torch.stack (0) 2023.09.29 CNN ๋ชจ๋ธ์ classifier๋จ์์ FC layer์ ์ ๋ ฅ ๋ ธ๋ ๊ฐ์ (0) 2023.08.03