ddpm.py
import torch
import numpy as np


class DDPMSampler:
    """DDPM sampler: forward noising (add_noise) and one reverse denoising update per step()."""

    def __init__(self, generator: torch.Generator, num_training_steps: int = 1000,
                 beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Scaled-linear beta schedule (linear in sqrt(beta), as used by Stable Diffusion).
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        # alpha_bar_t = prod_{s <= t} alpha_s
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        self.generator = generator
        self.num_train_timesteps = num_training_steps
        # Timesteps run backwards: T-1, T-2, ..., 0.
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps: int = 50):
        # Keep every (T // N)-th training timestep, still in descending order.
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        # With N inference steps, the previous timestep is one stride of T // N earlier.
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        # Variance of the posterior q(x_{t-1} | x_t, x_0):
        #     (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        prev_t = self._get_previous_timestep(timestep)
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Clamp so the square root below never hits exactly zero at small t.
        variance = torch.clamp(variance, min=1e-20)
        return variance

    def set_strength(self, strength: float = 1):
        # Strength in (0, 1] controls how much of the diffusion process is applied to an
        # input image; lower strength skips the earliest (noisiest) inference steps.
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        # One reverse-diffusion update: sample x_{t-1} from p(x_{t-1} | x_t),
        # given the model's predicted noise (epsilon) in model_output.
        t = timestep
        prev_t = self._get_previous_timestep(t)

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # Predict x_0 from x_t and the predicted noise.
        pred_original_sample = (latents - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

        # Coefficients of the posterior mean mu(x_t, x_0).
        pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # Add noise scaled by the posterior standard deviation, except at the final step (t == 0).
        variance = 0
        if t > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            variance = (self._get_variance(t) ** 0.5) * noise
        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Forward process q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        # Broadcast sqrt(alpha_bar_t) to the sample shape (e.g. (B, 1, 1, 1) for (B, C, H, W) samples).
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
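

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal denoising loop that
# exercises set_inference_timesteps, step, and add_noise. The denoise_model
# below is a hypothetical stand-in for a real noise-prediction network (e.g. a
# U-Net); in practice it would take the latents and the timestep and return a
# predicted-noise tensor of the same shape. The latent shape is illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    generator = torch.Generator().manual_seed(0)
    sampler = DDPMSampler(generator)
    sampler.set_inference_timesteps(50)

    # Hypothetical noise predictor; a real model would be a trained network.
    def denoise_model(latents: torch.Tensor, timestep: int) -> torch.Tensor:
        return torch.zeros_like(latents)

    # Reverse process: start from pure Gaussian noise and denoise step by step.
    latents = torch.randn((1, 4, 64, 64), generator=generator)
    for t in sampler.timesteps:
        model_output = denoise_model(latents, int(t))
        latents = sampler.step(int(t), latents, model_output)

    # Forward process: noise a clean sample to the noise level of timestep 500.
    clean = torch.randn((1, 4, 64, 64), generator=generator)
    noisy = sampler.add_noise(clean, torch.tensor([500]))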