Skip to content

Commit

Permalink
Masked loss patched for stable cascade training
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Ding committed Apr 4, 2024
1 parent 1c1173b commit c27c282
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion sd_scripts/stable_cascade_train_c_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,9 +896,10 @@ def remove_model(old_ckpt_name):
pred = stage_c(
noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
)
loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean(dim=[1, 2, 3])
loss_adjusted = (loss * loss_weight).mean()

if args.adaptive_loss_weight:
Expand Down

0 comments on commit c27c282

Please sign in to comment.