Skip to content

Commit

Permalink
WIP custom train function for bowtie masks
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Ding committed Apr 30, 2024
1 parent 3f10c45 commit c5c615d
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions sd_scripts/library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,39 @@ def apply_masked_loss(loss, batch):
loss = loss * mask_image
return loss, mask_image

def apply_bowtie_masked_loss(loss, batch):
# Merge the MrM and Bowtie masks to have MrM mask in channel 1 and bowtie mask in channel 2
mask1 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)
mask2 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 1].unsqueeze(1)

# resize to the same size as the loss
mask1 = torch.nn.functional.interpolate(mask1, size=loss.shape[2:], mode="area")
mask2 = torch.nn.functional.interpolate(mask2, size=loss.shape[2:], mode="area")

# Normalize masks to range 0 to 1
mask1 = mask1 / 2 + 0.5
mask2 = mask2 / 2 + 0.5

# invert mask 1, so that we can separate MrM and the background and calculate separately
mask1_inv = 1 - mask1

# Assuming mask channel is either 0 or 1, calculate total # of pixels "in" the mask
bowtie_pixels = mask2.sum()

background_weight = 1.0
mrm_weight = 1.5
bowtie_weight = 2.0

background_loss = background_weight * (loss * mask1)
mrm_loss = mrm_weight * (loss * mask1_inv)
bowtie_loss = torch.maximum(torch.tensor(0.0, dtype=loss.dtype, device=loss.device),
-1 + bowtie_weight * (torch.log(loss.numel()) - torch.log(bowtie_pixels)) * (loss * mask2))

final_loss = background_loss + mrm_loss + bowtie_loss

return final_loss, mask1, mask2


"""
##########################################
# Perlin Noise
Expand Down

0 comments on commit c5c615d

Please sign in to comment.