Skip to content

Commit

Permalink
Temporary fix for NaNs caused by log(0); still requires testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Ding committed May 2, 2024
1 parent a5655d4 commit ea366ea
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions sd_scripts/library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,18 +516,19 @@ def apply_multichannel_masked_loss(loss, batch, *weights: float) -> Tuple[Tensor

# compute weights
background_weight, character_weight, detail_weight = weights
detail_weight_with_size_correction = (torch.log(torch.tensor(reduce(mul, loss_height_width))).float() - torch.log(size_of_detail_tensor.float())) * detail_weight
detail_weight_with_size_correction = (torch.log(torch.tensor(reduce(mul, loss_height_width))).float() - torch.log(size_of_detail_tensor.float() + 1)) * detail_weight

background_loss = background_weight * (loss * background_mask_latent)
character_loss = character_weight * (loss * character_mask_latents)

nonzero_detail_loss = detail_weight_with_size_correction * (loss * detail_mask_latents)
detail_loss = torch.where(size_of_detail_tensor > 0, nonzero_detail_loss, torch.tensor(0.0))

# logger.info(f"Previous loss: {loss.sum()}")
# logger.info(f"Total background loss: {background_loss.sum()}")
# logger.info(f"Total character loss: {character_loss.sum()}")
# logger.info(f"Total detail loss: {detail_loss.sum()}")
# with torch.no_grad():
# logger.info(f"Previous loss: {loss.sum()}")
# logger.info(f"Total background loss: {background_loss.sum()}")
# logger.info(f"Total character loss: {character_loss.sum()}")
# logger.info(f"Total detail loss: {detail_loss.sum()}")

final_loss = background_loss + character_loss + detail_loss
return final_loss, background_mask_latent
Expand Down

0 comments on commit ea366ea

Please sign in to comment.