@@ -5829,8 +5829,8 @@ def save_sd_model_on_train_end_common(
5829
5829
5830
5830
5831
5831
def get_timesteps(min_timestep, max_timestep, b_size, device) -> torch.Tensor:
    """Draw one random integer timestep per batch element.

    Args:
        min_timestep: Lower bound of the sampled range (inclusive, per
            ``torch.randint``).
        max_timestep: Upper bound of the sampled range (exclusive, per
            ``torch.randint``).
        b_size: Number of timesteps to draw; the result has shape ``(b_size,)``.
        device: Device the returned tensor is moved to.

    Returns:
        A 1-D ``torch.long`` tensor of length ``b_size`` located on ``device``.
    """
    # The values are drawn on the CPU and only then transferred to ``device``,
    # so the draw always goes through the CPU random generator.
    # NOTE(review): presumably this keeps sampling independent of the target
    # device's RNG — confirm against the change that introduced it.
    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
    timesteps = timesteps.long().to(device)
    return timesteps
5835
5835
5836
5836
@@ -5875,8 +5875,8 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch
5875
5875
alpha = - math .log (args .huber_c ) / noise_scheduler .config .num_train_timesteps
5876
5876
result = torch .exp (- alpha * timesteps ) * args .huber_scale
5877
5877
elif args .huber_schedule == "snr" :
5878
- if not hasattr (noise_scheduler , ' alphas_cumprod' ):
5879
- raise NotImplementedError (f "Huber schedule 'snr' is not supported with the current model." )
5878
+ if not hasattr (noise_scheduler , " alphas_cumprod" ):
5879
+ raise NotImplementedError ("Huber schedule 'snr' is not supported with the current model." )
5880
5880
alphas_cumprod = torch .index_select (noise_scheduler .alphas_cumprod , 0 , timesteps .cpu ())
5881
5881
sigmas = ((1.0 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
5882
5882
result = (1 - args .huber_c ) / (1 + sigmas ) ** 2 + args .huber_c
0 commit comments