ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • ๋ถ„์‚ฐ ๋ฐ ๋ณ‘๋ ฌ ํ•™์Šต
    AI\ML\DL/Pytorch 2023. 12. 13. 15:27
    ๋ฐ˜์‘ํ˜•

    ๋ถ„์‚ฐ ํ•™์Šต(Distributed training)์€ ํ•™์Šต ์›Œํฌ๋กœ๋“œ๋ฅผ ์—ฌ๋Ÿฌ ์ž‘์—…์ž ๋…ธ๋“œ์— ๋ถ„์‚ฐ์‹œ์ผœ ํ›ˆ๋ จ ์†๋„์™€ ๋ชจ๋ธ ์ •ํ™•๋„๋ฅผ ํฌ๊ฒŒ ํ–ฅ์ƒ์‹œํ‚ค๋Š” ๋ชจ๋ธ ํ•™์Šต ํŒจ๋Ÿฌ๋‹ค์ž„์ด๋‹ค. PyTorch์—์„œ ๋ถ„์‚ฐ ํ•™์Šต์„ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ๋Š” ๋ช‡ ๊ฐ€์ง€ ๋ฐฉ๋ฒ•์ด ์žˆ์œผ๋ฉฐ ๊ฐ ๋ฐฉ๋ฒ•์€ ์‚ฌ์šฉ ์šฉ๋„๋ณ„๋กœ ์žฅ์ ์„ ๊ฐ€์ง„๋‹ค. 

    • DistributedDataParallel (DDP)
    • Fully Shared Data Parallel (FSDP)
    • Remote Procedure Call (RPC) distributed traininng 
    • Custom Extensions

     

    ๋น„๋ถ„์‚ฐ ํ•™์Šต์€ ๋‹จ์ผ GPU์—์„œ ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚จ๋‹ค. ํ•™์Šต ๊ณผ์ •์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค. 

    1. DataLoader๋กœ๋ถ€ํ„ฐ ์ž…๋ ฅ ๋ฐฐ์น˜๋ฅผ ๋ฐ›๋Š”๋‹ค. 

    2. Forward pass๋ฅผ ํ†ตํ•ด loss ๊ณ„์‚ฐ

    3. Backward pass๋ฅผ ํ†ตํ•ด ๋งค๊ฐœ๋ณ€์ˆ˜ ๊ธฐ์šธ๊ธฐ ๊ณ„์‚ฐ

    4. Optimizer๊ฐ€ ๊ธฐ์šธ๊ธฐ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ ์—…๋ฐ์ดํŠธ

     

    ๋ฐ˜๋ฉด ๋ถ„์‚ฐ ํ•™์Šต(DDP)๋Š” ํ•™์Šต ์ž‘์—…์„ ์—ฌ๋Ÿฌ๊ฐœ์˜ GPU๋กœ ๋ถ„์‚ฐํ•œ๋‹ค. ๊ฐ GPU ํ”„๋กœ์„ธ์Šค๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๋ชจ๋ธ์˜ ๋กœ์ปฌ ๋ณต์‚ฌ๋ณธ์„ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค. 

     

    ์ด๋•Œ ๋ชจ๋“  ๋ชจ๋ธ ๋ณต์ œ๋ณธ๊ณผ ์˜ตํ‹ฐ๋งˆ์ด์ €๋Š” ์„œ๋กœ ๋™์ผํ•˜๋ฉฐ, ์ดˆ๊ธฐ ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ์˜ตํ‹ฐ๋งˆ์ด์ €๋„ ๋™์ผํ•œ ๋žœ๋ค ์‹œ๋“œ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. DDP๋Š” ํ•™์Šต ๊ณผ์ • ์ „์ฒด์—์„œ ์ด ๋™๊ธฐํ™”๋ฅผ ์œ ์ง€ํ•œ๋‹ค. 

     

    ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ

    ๊ฐ GPU ํ”„๋กœ์„ธ์Šค๋Š” ๋™์ผํ•œ ๋ชจ๋ธ์„ ๊ฐ€์ง€๊ณ  ์žˆ์ง€๋งŒ DistributedSampler์— ์˜ํ•ด ๋ฐ์ดํ„ฐ๋กœ๋”๋กœ๋ถ€ํ„ฐ ๋ฐ›์€ input batch ๋ฅผ ๋‚˜๋ˆ ๋ฐ›๊ธฐ ๋•Œ๋ฌธ์— ๊ฐ ๋ชจ๋ธ์ด ์ „๋‹ฌ๋ฐ›๋Š” ๋ฐ์ดํ„ฐ๋Š” ๋‹ค๋ฅด๋‹ค. (์ƒ˜ํ”Œ๋Ÿฌ๋Š” ๊ฐ ํ”„๋กœ์„ธ์Šค๊ฐ€ ๋‹ค๋ฅธ ์ž…๋ ฅ์„ ๋ฐ›๋„๋ก ํ•˜๋Š” ์—ญํ• ์„ ํ•œ๋‹ค.) ์ด ๊ณผ์ •์„ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๋ผ๊ณ  ํ•œ๋‹ค. 

     

     

    ๊ฐ ํ”„๋กœ์„ธ์Šค์—์„œ ๋ชจ๋ธ(worker)์€ Localํ•˜๊ฒŒ forward ๋ฐ backward pass๋ฅผ ์‹คํ–‰ํ•œ๋‹ค. ์ด๋•Œ ์„œ๋กœ ๋‹ค๋ฅธ ์ž…๋ ฅ์„ ์ฒ˜๋ฆฌํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ˆ„์ ๋œ gradient๊ฐ’๋„ ๋‹ค๋ฅด๋‹ค. ์ด ์‹œ์ ์—์„œ optimizer ๋‹จ๊ณ„๋ฅผ ์‹คํ–‰ํ•˜๋ฉด ํ”„๋กœ์„ธ์Šค๊ฐ„ ์„œ๋กœ ๋‹ค๋ฅธ ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ์ƒ์„ฑ๋œ๋‹ค. 

    DDP ๋Š” ์ด ์‹œ์ ์—์„œ synchronization ์„ ์‹œ์ž‘ํ•œ๋‹ค. ๊ฐ ํ”„๋กœ์„ธ์Šค(๋ชจ๋ธ)๋กœ๋ถ€ํ„ฐ ํ•™์Šต๋œ gradient๋Š” bucket-allreduce ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ํ™œ์šฉํ•˜์—ฌ ์ง‘๊ณ„๋œ๋‹ค. ์ด ๋™๊ธฐํ™” ๊ณผ์ •์€ ๋‹ค๋ฅธ ํ”„๋กœ์„ธ์Šค์˜ gradient๊ฐ€ ๊ณ„์‚ฐ๋  ๋•Œ๊นŒ์ง€ ๊ธฐ๋‹ค๋ฆฌ์ง€ ์•Š๊ณ  ์•„์ง ์‹คํ–‰์ค‘์ธ ๋ง์„ ๋”ฐ๋ผ ํ†ต์‹ ์„ ์‹œ์ž‘ํ•˜๊ณ  ๋™๊ธฐํ™”๋ฅผ ํ•œ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด GPU๊ฐ€ ํ•ญ์‹œ ์ž‘๋™ํ•˜๊ฒŒ ๋˜๊ณ  idle ์ƒํƒœ(์‚ฌ์šฉ๊ฐ€๋Šฅํ•œ ์ƒํƒœ์ด์ง€๋งŒ ๋…ธ๋Š” ์ƒํƒœ) ๊ฐ€ ๋˜์ง€ ์•Š๋Š”๋‹ค. Optimizer ๋‹จ๊ณ„์—์„œ๋Š” ๋™๊ธฐํ™”๋˜์–ด ์ง‘๊ณ„๋œ ๋ชจ๋ธ๋“ค์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ํ•œ ๋ฒˆ์— ์—…๋ฐ์ดํŠธ๋œ๋‹ค. 

Designed by Tistory.