diff --git a/sd_scripts/library/custom_train_functions.py b/sd_scripts/library/custom_train_functions.py index 1f35f8798..dc7fdfd22 100644 --- a/sd_scripts/library/custom_train_functions.py +++ b/sd_scripts/library/custom_train_functions.py @@ -490,14 +490,14 @@ def apply_multichannel_masked_loss(loss, batch, weight1, weight2, weight3): mask1 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) mask2 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 1].unsqueeze(1) - # Logging for debugging if needed - threshold = 0.5 - red_pixels = (mask1 > threshold).sum().item() - green_pixels = (mask2 > threshold).sum().item() - - total_pixels = mask1.numel() - - logger.info(f"Red pixels: {red_pixels}, green pixels: {green_pixels}, total pixels: {total_pixels}") + # # Logging for debugging if needed + # threshold = 0.5 + # red_pixels = (mask1 > threshold).sum().item() + # green_pixels = (mask2 > threshold).sum().item() + # + # total_pixels = mask1.numel() + # + # logger.info(f"Red pixels: {red_pixels}, green pixels: {green_pixels}, total pixels: {total_pixels}") # resize to the same size as the loss mask1 = torch.nn.functional.interpolate(mask1, size=loss.shape[2:], mode="area") @@ -522,11 +522,11 @@ def apply_multichannel_masked_loss(loss, batch, weight1, weight2, weight3): detail_loss = (detail_weight * (torch.log(torch.tensor(loss.numel(), device=loss.device).float()) - torch.log(bowtie_pixels.float())) * (loss * mask2)) detail_loss = torch.maximum(torch.tensor(0.0, dtype=loss.dtype, device=loss.device), -1 + detail_loss) - - logger.info(f"Previous loss: {loss.sum()}") - logger.info(f"Total background loss: {background_loss.sum()}") - logger.info(f"Total character loss: {character_loss.sum()}") - logger.info(f"Total detail loss: {detail_loss.sum()}") + # + # logger.info(f"Previous loss: {loss.sum()}") + # logger.info(f"Total background loss: {background_loss.sum()}") + # logger.info(f"Total character loss: {character_loss.sum()}") + # logger.info(f"Total detail loss: {detail_loss.sum()}") final_loss = background_loss + character_loss + detail_loss