Skip to content

Commit

Permalink
Updated custom loss function for weighing different parts of the imag…
Browse files Browse the repository at this point in the history
…e, requires testing
  • Loading branch information
Jeff Ding committed May 1, 2024
1 parent c5c615d commit 52e4c8b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 15 deletions.
36 changes: 25 additions & 11 deletions sd_scripts/library/custom_train_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import numpy as np
import argparse
import random
import re
Expand Down Expand Up @@ -483,11 +484,23 @@ def apply_masked_loss(loss, batch):
loss = loss * mask_image
return loss, mask_image

def apply_bowtie_masked_loss(loss, batch):
# Merge the MrM and Bowtie masks to have MrM mask in channel 1 and bowtie mask in channel 2
## Custom loss function for weighing a character, and specifical details of the character, differently from the background
def apply_multichannel_masked_loss(loss, batch, weight1, weight2, weight3):
# Merge the character and detail masks to have character mask in channel 1 and detail mask in channel 2
mask1 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)
mask2 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 1].unsqueeze(1)

# Logging for debugging if needed
# threshold = 0.5
# red_pixels = (mask1 > threshold).sum().item()
# green_pixels = (mask2 > threshold).sum().item()
#
# total_red_pixels = mask1.numel()
# total_green_pixels = mask2.numel()
#
# logger.info(f"Number of 'red' pixels in mask1 above threshold: {red_pixels}, total pixels: {total_red_pixels}")
# logger.info(f"Number of 'green' pixels in mask2 above threshold: {green_pixels}, total pixels: {total_green_pixels}")

# resize to the same size as the loss
mask1 = torch.nn.functional.interpolate(mask1, size=loss.shape[2:], mode="area")
mask2 = torch.nn.functional.interpolate(mask2, size=loss.shape[2:], mode="area")
Expand All @@ -502,18 +515,19 @@ def apply_bowtie_masked_loss(loss, batch):
# Assuming mask channel is either 0 or 1, calculate total # of pixels "in" the mask
bowtie_pixels = mask2.sum()

background_weight = 1.0
mrm_weight = 1.5
bowtie_weight = 2.0
background_weight = weight1
character_weight = weight2
detail_weight = weight3

background_loss = background_weight * (loss * mask1_inv)
character_loss = character_weight * (loss * mask1)
detail_loss = (detail_weight * (torch.log(torch.tensor(loss.numel(), device=loss.device).float()) - torch.log(bowtie_pixels.float())) * (loss * mask2))

background_loss = background_weight * (loss * mask1)
mrm_loss = mrm_weight * (loss * mask1_inv)
bowtie_loss = torch.maximum(torch.tensor(0.0, dtype=loss.dtype, device=loss.device),
-1 + bowtie_weight * (torch.log(loss.numel()) - torch.log(bowtie_pixels)) * (loss * mask2))
detail_loss = torch.maximum(torch.tensor(0.0, dtype=loss.dtype, device=loss.device), -1 + detail_loss)

final_loss = background_loss + mrm_loss + bowtie_loss
final_loss = background_loss + character_loss + detail_loss

return final_loss, mask1, mask2
return final_loss, mask1_inv


"""
Expand Down
8 changes: 6 additions & 2 deletions sd_scripts/library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,8 +1224,12 @@ def __getitem__(self, index):
images.append(image)
latents_list.append(latents)

target_size = (image.shape[2], image.shape[1]) if image is not None else (
latents.shape[2] * 8, latents.shape[1] * 8)
if image is not None:
target_size = (image.shape[2], image.shape[1])
elif image_info.latents_npz is not None:
target_size = (latents.shape[2] * 32, latents.shape[1] * 32)
else:
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)

if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
Expand Down
5 changes: 3 additions & 2 deletions sd_scripts/stable_cascade_train_c_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .library import stable_cascade as sc
from .library.sdxl_train_util import add_sdxl_training_arguments
from .library.custom_train_functions import (
apply_masked_loss,
apply_masked_loss, apply_multichannel_masked_loss,
)
from .library.utils import setup_logging, add_logging_arguments

Expand Down Expand Up @@ -900,7 +900,8 @@ def remove_model(old_ckpt_name):
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

if args.masked_loss and np.random.rand() < args.masked_loss_prob:
loss, noise_mask = apply_masked_loss(loss, batch)
# loss, noise_mask = apply_masked_loss(loss, batch)
loss, noise_mask = apply_multichannel_masked_loss(loss, batch, 1.0, 1.5, 3.0)
else:
noise_mask = torch.ones_like(noise, device=noise.device)

Expand Down

0 comments on commit 52e4c8b

Please sign in to comment.