  • torch.cumprod() w.r.t. diffusion noise scheduling
    AI\ML\DL/Pytorch 2023. 12. 3. 19:25

     

     

    [cumprod: the cumulative product function]

     

    torch.cumprod is a function provided by PyTorch that computes the cumulative product of a tensor's elements along a given dimension.

    It is used as follows.

     

    import torch
    
    tensor = torch.tensor([1, 2, 3, 4])
    cumprod_result = torch.cumprod(tensor, dim=0)
    print(cumprod_result)
    tensor([ 1,  2,  6, 24])

     

    For example, applying torch.cumprod with dim=0 (the first dimension) to the tensor 'torch.tensor([1,2,3,4])' computes the running product up to each element and returns [1,2,6,24].

    1

    1x2=2

    1x2x3=6

    1x2x3x4=24
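For a tensor with more than one dimension, `dim` selects the axis along which the running product is taken. The following is a minimal sketch; the 2x3 tensor is a made-up example and does not appear in the original post.

```python
import torch

# Hypothetical 2x3 tensor, chosen only for illustration.
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# dim=0: the running product moves down each column (row 0, then row 0 * row 1).
down_rows = torch.cumprod(x, dim=0)

# dim=1: the running product moves across each row, left to right.
across_cols = torch.cumprod(x, dim=1)

print(down_rows.tolist())    # [[1, 2, 3], [4, 10, 18]]
print(across_cols.tolist())  # [[1, 2, 6], [4, 20, 120]]
```

The diffusion schedule discussed next is a 1-D tensor, so `dim=0` is the only axis available there.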

     


     

    [Usage in diffusion models]

    class Diffusion:
        def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
            self.noise_steps = noise_steps
            self.beta_start = beta_start
            self.beta_end = beta_end
            self.img_size = img_size
            self.device = device
    
            self.beta = self.prepare_noise_schedule().to(device)
            self.alpha = 1. - self.beta
            self.alpha_hat = torch.cumprod(self.alpha, dim=0)
    
        def prepare_noise_schedule(self):
            return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)  # returns the beta value for each timestep
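To get a feel for what this schedule looks like, the same three tensors can be built outside the class and inspected. This is only a sketch: it reuses the default values from the constructor above but stays on the CPU (an assumption made so that it runs without a GPU).

```python
import torch

# Same defaults as the class above, kept on the CPU so the sketch runs anywhere.
noise_steps, beta_start, beta_end = 1000, 1e-4, 0.02

beta = torch.linspace(beta_start, beta_end, noise_steps)  # linear noise schedule
alpha = 1.0 - beta                                        # per-step signal retention
alpha_hat = torch.cumprod(alpha, dim=0)                   # cumulative signal retention

# alpha_hat starts close to 1 (almost no noise) and decays toward 0 (almost pure noise).
print(alpha_hat[0].item(), alpha_hat[noise_steps // 2].item(), alpha_hat[-1].item())
```

Because every alpha is strictly below 1, the cumulative product decreases monotonically with the timestep.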

     

     

    In a diffusion model, the noise level differs at each step of the forward and backward (reverse) processes, because noise is gradually added to or removed from the image of the previous step. Taking the forward process as an example, when noise is applied to go from the previous step's image to the next step's image, the image and the noise are multiplied by the coefficients $\sqrt{\alpha_t}$ and $\sqrt{1-\alpha_{t}}$, respectively.

    $$X_t=\sqrt{\alpha_t}X_{t-1}+\sqrt{1-\alpha_{t}}\epsilon_{t-1}, \  \ \alpha _{t}=1-\beta_{t}$$
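The single-step update can be sketched in code as follows. The function name `forward_step` and the small random tensor standing in for an image are assumptions made for illustration; they are not part of the original post.

```python
import torch

def forward_step(x_prev, alpha_t, eps=None):
    """One forward step: x_t = sqrt(alpha_t) * x_{t-1} + sqrt(1 - alpha_t) * eps."""
    if eps is None:
        eps = torch.randn_like(x_prev)  # standard Gaussian noise, same shape as the image
    alpha_t = torch.as_tensor(alpha_t, dtype=x_prev.dtype)
    return torch.sqrt(alpha_t) * x_prev + torch.sqrt(1.0 - alpha_t) * eps

torch.manual_seed(0)
x0 = torch.rand(1, 3, 8, 8)          # stand-in for a small image batch
beta_1 = 1e-4                        # first beta of the linear schedule
x1 = forward_step(x0, 1.0 - beta_1)  # slightly noisier version of x0
```

Reaching step t this way requires t sequential calls, which is the cost the cumulative-product shortcut avoids.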

    However, instead of stepping from t-1 to t one step at a time, it is also possible to go from t=0 directly to a later step (up to t=T) by using the cumulative product of the coefficients.

    ์œ„์˜ forward process ์‹์— t=0, 1, 2, ... ๋ฅผ ๋Œ€์ž…ํ•ด๋ณด๋ฉด ์•„๋ž˜์™€ ๊ฐ™์€ ์‹๋“ค์„ ์–ป์„ ์ˆ˜ ์žˆ๋‹ค. 

     

     

    Substituting equation ① into equation ② gives the following.

    $$X_2=\sqrt{\alpha_2\alpha_1}X_0+\sqrt{\alpha_2(1-\alpha_1)}\epsilon_0+\sqrt{1-\alpha_2}\epsilon_1$$

     

     

    Here $\epsilon_{0}$ and $\epsilon_{1}$ are tensors obtained by randomly drawing as many values as there are samples (= the shape of the input data) from a Gaussian distribution with mean 0 and variance 1. (In theory, each epsilon is assumed to follow a standard Gaussian distribution, independently of the others.) The two noise terms $\epsilon_{0}, \epsilon_{1}$ can then be merged. By the properties of variance, the square of the coefficient in front of each epsilon becomes the variance of that term.

    That is, $\alpha_2(1-\alpha _{1})$ is the variance of the $\epsilon_0$ term and $(1-\alpha_{2})$ is the variance of the $\epsilon_1$ term.

    Adding two independent zero-mean Gaussian terms produces a new Gaussian whose variance is the sum of the two variances. Here the sum is $\alpha_2(1-\alpha_1)+(1-\alpha_2)=1-\alpha_1\alpha_2$, and repeating the same argument up to step t leads to the cumulative product below.
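The variance argument can be checked numerically. The sketch below uses arbitrary values for alpha_1 and alpha_2 (assumptions for illustration, not taken from a real schedule) and compares the empirical variance of the merged noise term with 1 - alpha_1 * alpha_2.

```python
import torch

torch.manual_seed(0)
alpha_1, alpha_2 = 0.9, 0.8   # arbitrary illustrative values, not from a real schedule
n = 1_000_000

eps_0 = torch.randn(n)        # independent standard Gaussian samples
eps_1 = torch.randn(n)

# Combined noise term obtained after substituting X_1 into X_2.
merged = (alpha_2 * (1 - alpha_1)) ** 0.5 * eps_0 + (1 - alpha_2) ** 0.5 * eps_1

empirical_var = merged.var().item()
expected_var = 1 - alpha_1 * alpha_2   # alpha_2 * (1 - alpha_1) + (1 - alpha_2)
print(empirical_var, expected_var)
```

With a million samples the two numbers should agree to about two or three decimal places.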

     

    $$\bar{ \alpha _{t}}=\prod_{i=1}^{t}\alpha _{i}=\prod_{i=1}^{t}(1-\beta _{i})$$

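    Therefore, X_t can be obtained by multiplying X_0 by the square root of the cumulative product of the alphas and multiplying a single noise term e_0 by the square root of 1 - (cumulative product of the alphas), that is, $X_t=\sqrt{\bar{\alpha}_t}X_0+\sqrt{1-\bar{\alpha}_t}\epsilon_0$.

    torch.cumprod is used to compute this cumulative product of the alphas.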

     

    self.alpha_hat = torch.cumprod(self.alpha, dim=0)
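As a sanity check, the result of `torch.cumprod` can be compared against an explicit running product. This is a sketch under one assumption that differs from the class above: the schedule is built in float64 so that rounding differences do not blur the comparison.

```python
import torch

# float64 is an assumption made here to keep the comparison free of rounding noise.
beta = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alpha = 1.0 - beta

# Explicit running product: alpha_bar_t = alpha_1 * alpha_2 * ... * alpha_t.
running = []
product = torch.tensor(1.0, dtype=torch.float64)
for a in alpha:
    product = product * a
    running.append(product)
alpha_hat_loop = torch.stack(running)

# Vectorized equivalent; index t-1 of the tensor holds alpha_bar_t (0-based indexing).
alpha_hat = torch.cumprod(alpha, dim=0)
print(torch.allclose(alpha_hat_loop, alpha_hat))
```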

     

    This determines the noise level at each step.
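With the cumulative product precomputed, an image can be noised to any timestep in a single call. The following is a minimal sketch of that closed-form jump; the function name `noise_images`, the batch shape, and the CPU-only setup are assumptions for illustration and are not taken from the post.

```python
import torch

noise_steps = 1000
beta = torch.linspace(1e-4, 0.02, noise_steps)
alpha_hat = torch.cumprod(1.0 - beta, dim=0)

def noise_images(x0, t, eps=None):
    """Jump from x0 directly to x_t for a batch of 0-indexed timesteps t."""
    if eps is None:
        eps = torch.randn_like(x0)
    # Reshape the per-sample coefficients to (B, 1, 1, 1) so they broadcast over C, H, W.
    sqrt_alpha_hat = torch.sqrt(alpha_hat[t]).view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - alpha_hat[t]).view(-1, 1, 1, 1)
    return sqrt_alpha_hat * x0 + sqrt_one_minus_alpha_hat * eps, eps

torch.manual_seed(0)
x0 = torch.rand(4, 3, 8, 8)               # stand-in image batch
t = torch.randint(0, noise_steps, (4,))   # one random timestep per image
x_t, eps = noise_images(x0, t)
```

Returning the sampled noise alongside the noised image is a design choice made here because a denoising network is typically trained to predict that noise.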

     

    The detailed derivation of the diffusion equations is laid out clearly at the link below.

    https://lilianweng.github.io/posts/2021-07-11-diffusion-models/

     

    What are Diffusion Models?


     
