Skip to content

Commit

Permalink
Commented out debugging lines
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Ding committed May 2, 2024
1 parent 489f84a commit f6483bc
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions sd_scripts/library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,14 @@ def apply_multichannel_masked_loss(loss, batch, weight1, weight2, weight3):
mask1 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)
mask2 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 1].unsqueeze(1)

# Logging for debugging if needed
threshold = 0.5
red_pixels = (mask1 > threshold).sum().item()
green_pixels = (mask2 > threshold).sum().item()

total_pixels = mask1.numel()

logger.info(f"Red pixels: {red_pixels}, green pixels: {green_pixels}, total pixels: {total_pixels}")
# # Logging for debugging if needed
# threshold = 0.5
# red_pixels = (mask1 > threshold).sum().item()
# green_pixels = (mask2 > threshold).sum().item()
#
# total_pixels = mask1.numel()
#
# logger.info(f"Red pixels: {red_pixels}, green pixels: {green_pixels}, total pixels: {total_pixels}")

# resize to the same size as the loss
mask1 = torch.nn.functional.interpolate(mask1, size=loss.shape[2:], mode="area")
Expand All @@ -522,11 +522,11 @@ def apply_multichannel_masked_loss(loss, batch, weight1, weight2, weight3):
detail_loss = (detail_weight * (torch.log(torch.tensor(loss.numel(), device=loss.device).float()) - torch.log(bowtie_pixels.float())) * (loss * mask2))

detail_loss = torch.maximum(torch.tensor(0.0, dtype=loss.dtype, device=loss.device), -1 + detail_loss)

logger.info(f"Previous loss: {loss.sum()}")
logger.info(f"Total background loss: {background_loss.sum()}")
logger.info(f"Total character loss: {character_loss.sum()}")
logger.info(f"Total detail loss: {detail_loss.sum()}")
#
# logger.info(f"Previous loss: {loss.sum()}")
# logger.info(f"Total background loss: {background_loss.sum()}")
# logger.info(f"Total character loss: {character_loss.sum()}")
# logger.info(f"Total detail loss: {detail_loss.sum()}")

final_loss = background_loss + character_loss + detail_loss

Expand Down

0 comments on commit f6483bc

Please sign in to comment.