Skip to content

Commit

Permalink
Fixed a few bugs, added a check for no details mask pixels, to avoid …
Browse files Browse the repository at this point in the history
…NaN error (also allows you to just add a 'full' red image for one where you dont have separate masks
  • Loading branch information
Jeff Ding committed May 2, 2024
1 parent f6483bc commit f94b9c2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions sd_scripts/library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,17 +519,21 @@ def apply_multichannel_masked_loss(loss, batch, weight1, weight2, weight3):

background_loss = background_weight * (loss * mask1_inv)
character_loss = character_weight * (loss * mask1)
detail_loss = (detail_weight * (torch.log(torch.tensor(loss.numel(), device=loss.device).float()) - torch.log(bowtie_pixels.float())) * (loss * mask2))

detail_loss = torch.maximum(torch.tensor(0.0, dtype=loss.dtype, device=loss.device), -1 + detail_loss)
#
if bowtie_pixels == 0:
detail_loss = torch.tensor(0.0, dtype=loss.dtype, device=loss.device)
else:

detail_loss = (detail_weight * (torch.log(torch.tensor(loss.numel(), device=loss.device).float()) - torch.log(bowtie_pixels.float())) * (loss * mask2))

detail_loss = torch.maximum(torch.tensor(0.0, dtype=loss.dtype, device=loss.device), -1 + detail_loss)

# logger.info(f"Previous loss: {loss.sum()}")
# logger.info(f"Total background loss: {background_loss.sum()}")
# logger.info(f"Total character loss: {character_loss.sum()}")
# logger.info(f"Total detail loss: {detail_loss.sum()}")

final_loss = background_loss + character_loss + detail_loss

return final_loss, mask1_inv


Expand Down
2 changes: 1 addition & 1 deletion sd_scripts/stable_cascade_train_c_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ def remove_model(old_ckpt_name):

if args.masked_loss and np.random.rand() < args.masked_loss_prob:
# loss, noise_mask = apply_masked_loss(loss, batch)
loss, noise_mask = apply_multichannel_masked_loss(loss, batch, 1.0, 1.0, 1.0)
loss, noise_mask = apply_multichannel_masked_loss(loss, batch, 1.0, 1.5, 2.0)
noise_mask = torch.ones_like(noise, device=noise.device)
else:
noise_mask = torch.ones_like(noise, device=noise.device)
Expand Down

0 comments on commit f94b9c2

Please sign in to comment.