@@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
3905
3905
"--huber_c" ,
3906
3906
type = float ,
3907
3907
default = 0.1 ,
3908
- help = "The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1" ,
3908
+ help = "The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1" ,
3909
+ )
3910
+
3911
+ parser .add_argument (
3912
+ "--huber_scale" ,
3913
+ type = float ,
3914
+ default = 1.0 ,
3915
+ help = "The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1" ,
3909
3916
)
3910
3917
3911
3918
parser .add_argument (
@@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common(
5821
5828
huggingface_util .upload (args , out_dir , "/" + model_name , force_sync_upload = True )
5822
5829
5823
5830
5824
- def get_timesteps_and_huber_c (args , min_timestep , max_timestep , noise_scheduler , b_size , device ):
5825
- timesteps = torch .randint (min_timestep , max_timestep , (b_size ,), device = "cpu" )
5826
-
5827
- if args .loss_type == "huber" or args .loss_type == "smooth_l1" :
5828
- if args .huber_schedule == "exponential" :
5829
- alpha = - math .log (args .huber_c ) / noise_scheduler .config .num_train_timesteps
5830
- huber_c = torch .exp (- alpha * timesteps )
5831
- elif args .huber_schedule == "snr" :
5832
- alphas_cumprod = torch .index_select (noise_scheduler .alphas_cumprod , 0 , timesteps )
5833
- sigmas = ((1.0 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
5834
- huber_c = (1 - args .huber_c ) / (1 + sigmas ) ** 2 + args .huber_c
5835
- elif args .huber_schedule == "constant" :
5836
- huber_c = torch .full ((b_size ,), args .huber_c )
5837
- else :
5838
- raise NotImplementedError (f"Unknown Huber loss schedule { args .huber_schedule } !" )
5839
- huber_c = huber_c .to (device )
5840
- elif args .loss_type == "l2" :
5841
- huber_c = None # may be anything, as it's not used
5842
- else :
5843
- raise NotImplementedError (f"Unknown loss type { args .loss_type } " )
5844
-
5845
- timesteps = timesteps .long ().to (device )
5846
- return timesteps , huber_c
5831
def get_timesteps(min_timestep, max_timestep, b_size, device):
    """Sample ``b_size`` random integer timesteps in ``[min_timestep, max_timestep)``.

    The tensor is created directly on ``device`` and cast to int64 so it can be
    used to index scheduler tables.
    """
    # torch.randint's upper bound is exclusive; .long() guarantees int64.
    sampled = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
    return sampled.long()
5847
5835
5848
5836
5849
5837
def get_noise_noisy_latents_and_timesteps (args , noise_scheduler , latents ):
@@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
5865
5853
min_timestep = 0 if args .min_timestep is None else args .min_timestep
5866
5854
max_timestep = noise_scheduler .config .num_train_timesteps if args .max_timestep is None else args .max_timestep
5867
5855
5868
- timesteps , huber_c = get_timesteps_and_huber_c ( args , min_timestep , max_timestep , noise_scheduler , b_size , latents .device )
5856
+ timesteps = get_timesteps ( min_timestep , max_timestep , b_size , latents .device )
5869
5857
5870
5858
# Add noise to the latents according to the noise magnitude at each timestep
5871
5859
# (this is the forward diffusion process)
@@ -5878,32 +5866,54 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
5878
5866
else :
5879
5867
noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
5880
5868
5881
- return noise , noisy_latents , timesteps , huber_c
5869
+ return noise , noisy_latents , timesteps
5870
+
5871
+
5872
def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
    """Compute the per-sample pseudo-Huber threshold for the given timesteps.

    The schedule is chosen by ``args.huber_schedule`` ("exponential", "snr" or
    "constant") and parameterized by ``args.huber_c`` and ``args.huber_scale``.
    Returns a 1-D tensor with one threshold per timestep on ``timesteps.device``.
    Raises NotImplementedError for an unknown schedule, or for "snr" when the
    scheduler exposes no ``alphas_cumprod`` table.
    """
    batch = timesteps.shape[0]
    schedule = args.huber_schedule

    if schedule == "exponential":
        # Decay rate chosen so the threshold falls to huber_c at the final timestep.
        decay = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
        return torch.exp(-decay * timesteps) * args.huber_scale

    if schedule == "snr":
        if not hasattr(noise_scheduler, 'alphas_cumprod'):
            raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.")
        # Gather on CPU (alphas_cumprod is indexed with CPU timesteps here),
        # then move the result back to the timesteps' device.
        alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
        sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
        threshold = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
        # NOTE(review): args.huber_scale is not applied on this branch — confirm intended.
        return threshold.to(timesteps.device)

    if schedule == "constant":
        return torch.full((batch,), args.huber_c * args.huber_scale, device=timesteps.device)

    raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
5882
5890
5883
5891
5884
5892
def conditional_loss(
    args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler
):
    """Compute the training loss selected by ``args.loss_type``.

    Supported types: "l2" (MSE), "l1", "huber" (threshold-scaled pseudo-Huber)
    and "smooth_l1" (unscaled pseudo-Huber). For the two Huber variants the
    per-sample threshold comes from ``get_huber_threshold`` and is broadcast
    over the (B, C, H, W) prediction; ``reduction`` of "mean" or "sum" is
    applied explicitly there, anything else leaves the elementwise loss.
    Raises NotImplementedError for an unknown loss type.
    """
    loss_type = args.loss_type
    if loss_type == "l2":
        loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
    elif loss_type == "l1":
        loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
    elif loss_type in ("huber", "smooth_l1"):
        # The two variants share the pseudo-Huber core and reduction logic and
        # differ only in the leading factor — deduplicated here.
        huber_c = get_huber_threshold(args, timesteps, noise_scheduler).view(-1, 1, 1, 1)
        core = torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c
        loss = 2 * huber_c * core if loss_type == "huber" else 2 * core
        if reduction == "mean":
            loss = torch.mean(loss)
        elif reduction == "sum":
            loss = torch.sum(loss)
    else:
        raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}")
    return loss
5908
5918
5909
5919
0 commit comments