Commit 420a180

Implement pseudo Huber loss for Flux and SD3
1 parent 2a61fc0 commit 420a180

15 files changed: +76 / -61 lines
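The loss added here is the pseudo-Huber loss, 2c(sqrt((pred - target)^2 + c^2) - c): approximately quadratic (L2-like, ~diff^2) for residuals much smaller than c and approximately linear (L1-like, ~2c|diff|) for residuals much larger than c, which makes training more robust to outlier samples. A minimal standalone sketch of the formula as it appears in `conditional_loss` in library/train_util.py below:

```python
import torch

def pseudo_huber(pred: torch.Tensor, target: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    # 2c * (sqrt(diff^2 + c^2) - c): ~= diff^2 when |diff| << c, ~= 2c * |diff| when |diff| >> c
    return 2 * c * (torch.sqrt((pred - target) ** 2 + c**2) - c)
```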

fine_tune.py (+3 -3)

@@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )

@@ -397,7 +397,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = train_util.conditional_loss(
-                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                        args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
                     )
                     loss = loss.mean([1, 2, 3])

@@ -411,7 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     loss = loss.mean()  # mean over batch dimension
                 else:
                     loss = train_util.conditional_loss(
-                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                        args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler
                     )

                 accelerator.backward(loss)

flux_train.py (+1 -1)

@@ -667,7 +667,7 @@ def grad_hook(parameter: torch.Tensor):

                 # calculate loss
                 loss = train_util.conditional_loss(
-                    model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
+                    args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler
                 )
                 if weighting is not None:
                     loss = loss * weighting

flux_train_network.py (+1 -1)

@@ -468,7 +468,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
             )
             target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

-        return model_pred, target, timesteps, None, weighting
+        return model_pred, target, timesteps, weighting

     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
         return loss

library/train_util.py (+42 -32)

@@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--huber_c",
         type=float,
         default=0.1,
-        help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
+        help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
+    )
+
+    parser.add_argument(
+        "--huber_scale",
+        type=float,
+        default=1.0,
+        help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のスケールパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは1.0",
     )

     parser.add_argument(

@@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common(
         huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)


-def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
-    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
-
-    if args.loss_type == "huber" or args.loss_type == "smooth_l1":
-        if args.huber_schedule == "exponential":
-            alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
-            huber_c = torch.exp(-alpha * timesteps)
-        elif args.huber_schedule == "snr":
-            alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
-            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
-            huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
-        elif args.huber_schedule == "constant":
-            huber_c = torch.full((b_size,), args.huber_c)
-        else:
-            raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
-        huber_c = huber_c.to(device)
-    elif args.loss_type == "l2":
-        huber_c = None  # may be anything, as it's not used
-    else:
-        raise NotImplementedError(f"Unknown loss type {args.loss_type}")
-
-    timesteps = timesteps.long().to(device)
-    return timesteps, huber_c
+def get_timesteps(min_timestep, max_timestep, b_size, device):
+    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
+    timesteps = timesteps.long()
+    return timesteps


 def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):

@@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     min_timestep = 0 if args.min_timestep is None else args.min_timestep
     max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

-    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
+    timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)

     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)

@@ -5878,32 +5866,54 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     else:
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-    return noise, noisy_latents, timesteps, huber_c
+    return noise, noisy_latents, timesteps
+
+
+def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
+    b_size = timesteps.shape[0]
+    if args.huber_schedule == "exponential":
+        alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+        result = torch.exp(-alpha * timesteps) * args.huber_scale
+    elif args.huber_schedule == "snr":
+        if not hasattr(noise_scheduler, "alphas_cumprod"):
+            raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
+        alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
+        sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+        result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+        result = result.to(timesteps.device)
+    elif args.huber_schedule == "constant":
+        result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
+    else:
+        raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+    return result


 def conditional_loss(
-    model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
+    args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler
 ):
-    if loss_type == "l2":
+    if args.loss_type == "l2":
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
-    elif loss_type == "l1":
+    elif args.loss_type == "l1":
         loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
-    elif loss_type == "huber":
+    elif args.loss_type == "huber":
+        huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
         huber_c = huber_c.view(-1, 1, 1, 1)
         loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
             loss = torch.sum(loss)
-    elif loss_type == "smooth_l1":
+    elif args.loss_type == "smooth_l1":
+        huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
         huber_c = huber_c.view(-1, 1, 1, 1)
         loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
             loss = torch.sum(loss)
     else:
-        raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+        raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}")
     return loss
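For reference, a quick sketch of how the new `get_huber_threshold` behaves under the three schedules. The bare `DDPMScheduler` and the `types.SimpleNamespace` standing in for the parsed argparse namespace are illustrative scaffolding, not part of the commit:

```python
import types

import torch
from diffusers import DDPMScheduler

from library import train_util

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.tensor([0, 500, 999])

for schedule in ("exponential", "snr", "constant"):
    args = types.SimpleNamespace(huber_schedule=schedule, huber_c=0.1, huber_scale=1.0)
    c = train_util.get_huber_threshold(args, timesteps, noise_scheduler)
    # "exponential" decays from huber_scale * 1.0 at t=0 down to ~huber_scale * huber_c at the last timestep;
    # "snr" runs from ~1.0 at low-noise timesteps down to ~huber_c at high-noise ones
    # (note that it does not apply huber_scale in this commit);
    # "constant" is huber_c * huber_scale at every timestep.
    print(schedule, c)
```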

sd3_train.py (+1 -1)

@@ -845,7 +845,7 @@ def grad_hook(parameter: torch.Tensor):
                 # )
                 # calculate loss
                 loss = train_util.conditional_loss(
-                    model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
+                    args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler
                 )
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                     loss = apply_masked_loss(loss, batch)

sd3_train_network.py (+1 -1)

@@ -378,7 +378,7 @@ def get_noise_pred_and_target(

         target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

-        return model_pred, target, timesteps, None, weighting
+        return model_pred, target, timesteps, weighting

     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
         return loss

sdxl_train.py (+3 -3)

@@ -695,7 +695,7 @@ def optimizer_hook(parameter: torch.Tensor):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )

@@ -720,7 +720,7 @@ def optimizer_hook(parameter: torch.Tensor):
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = train_util.conditional_loss(
-                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                        args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
                     )
                     if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                         loss = apply_masked_loss(loss, batch)

@@ -738,7 +738,7 @@ def optimizer_hook(parameter: torch.Tensor):
                     loss = loss.mean()  # mean over batch dimension
                 else:
                     loss = train_util.conditional_loss(
-                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                        args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler
                     )

                 accelerator.backward(loss)

sdxl_train_control_net.py (+2 -2)

@@ -512,7 +512,7 @@ def remove_model(old_ckpt_name):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )

@@ -534,7 +534,7 @@ def remove_model(old_ckpt_name):
                     target = noise

                 loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
                 )
                 loss = loss.mean([1, 2, 3])

sdxl_train_control_net_lllite.py (+2 -2)

@@ -463,7 +463,7 @@ def remove_model(old_ckpt_name):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )

@@ -485,7 +485,7 @@ def remove_model(old_ckpt_name):
                     target = noise

                 loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
                 )
                 loss = loss.mean([1, 2, 3])

sdxl_train_control_net_lllite_old.py (+4 -2)

@@ -406,7 +406,7 @@ def remove_model(old_ckpt_name):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -426,7 +426,9 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise

-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
+                )
                 loss = loss.mean([1, 2, 3])

                 loss_weights = batch["loss_weights"]  # weight for each sample

train_controlnet.py (+3 -3)

@@ -464,8 +464,8 @@ def remove_model(old_ckpt_name):
                 )

                 # Sample a random timestep for each image
-                timesteps, huber_c = train_util.get_timesteps_and_huber_c(
-                    args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
+                timesteps = train_util.get_timesteps(
+                    0, noise_scheduler.config.num_train_timesteps, b_size, latents.device
                 )

                 # Add noise to the latents according to the noise magnitude at each timestep

@@ -499,7 +499,7 @@ def remove_model(old_ckpt_name):
                     target = noise

                 loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
                 )
                 loss = loss.mean([1, 2, 3])

train_db.py (+2 -2)

@@ -370,7 +370,7 @@ def train(args):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )

@@ -385,7 +385,7 @@ def train(args):
                 target = noise

                 loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
                 )
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                     loss = apply_masked_loss(loss, batch)

train_network.py (+5 -4)

@@ -192,7 +192,7 @@ def get_noise_pred_and_target(
     ):
         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
-        noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

         # ensure the hidden state will require grad
         if args.gradient_checkpointing:

@@ -244,7 +244,7 @@ def get_noise_pred_and_target(
                 network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
             target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)

-        return noise_pred, target, timesteps, huber_c, None
+        return noise_pred, target, timesteps, None

     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
         if args.min_snr_gamma:

@@ -806,6 +806,7 @@ def load_model_hook(models, input_dir):
             "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
             "ss_loss_type": args.loss_type,
             "ss_huber_schedule": args.huber_schedule,
+            "ss_huber_scale": args.huber_scale,
             "ss_huber_c": args.huber_c,
             "ss_fp8_base": bool(args.fp8_base),
             "ss_fp8_base_unet": bool(args.fp8_base_unet),

@@ -1193,7 +1194,7 @@ def remove_model(old_ckpt_name):
                     text_encoder_conds[i] = encoded_text_encoder_conds[i]

             # sample noise, call unet, get target
-            noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
+            noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
                 args,
                 accelerator,
                 noise_scheduler,

@@ -1207,7 +1208,7 @@ def remove_model(old_ckpt_name):
             )

             loss = train_util.conditional_loss(
-                noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
             )
             if weighting is not None:
                 loss = loss * weighting
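Put together, a call site now follows the sketch below. Shapes, the `args` namespace, and the `DDPMScheduler` are illustrative stand-ins for what the trainers build from argparse and the model config:

```python
import types

import torch
from diffusers import DDPMScheduler

from library import train_util

args = types.SimpleNamespace(loss_type="smooth_l1", huber_schedule="exponential", huber_c=0.1, huber_scale=1.0)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

timesteps = train_util.get_timesteps(0, 1000, 2, torch.device("cpu"))
noise_pred = torch.randn(2, 4, 64, 64)
target = torch.randn(2, 4, 64, 64)

# per-sample loss (reduction "none"), then mean over channel/spatial dims as the trainers do
loss = train_util.conditional_loss(args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler)
loss = loss.mean([1, 2, 3])  # shape: (batch,)
```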

train_textual_inversion.py (+2 -2)

@@ -585,7 +585,7 @@ def remove_model(old_ckpt_name):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
                     args, noise_scheduler, latents
                 )

@@ -602,7 +602,7 @@ def remove_model(old_ckpt_name):
                 target = noise

                 loss = train_util.conditional_loss(
-                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
                 )
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                     loss = apply_masked_loss(loss, batch)
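Since `conditional_loss` now reads its configuration from `args`, the Huber behavior across all of these trainers is driven entirely by command-line flags, e.g. `--loss_type huber --huber_schedule snr --huber_c 0.1 --huber_scale 1.0`; with the default `--loss_type l2` the new threshold code path is never taken.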
