Skip to content

Commit

Permalink
Updates masked loss. Needs investigation for why last_hidden_states i…
Browse files Browse the repository at this point in the history
…n get_hidden_states_stable_cascade is NaN sometimes
  • Loading branch information
doctorpangloss committed May 2, 2024
1 parent f94b9c2 commit a5655d4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
visual_wrap = true
max_line_length = 120
max_line_length = 100000

[*]
indent_style = space
Expand Down
88 changes: 42 additions & 46 deletions sd_scripts/library/custom_train_functions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import torch
import numpy as np
import argparse
import random
import re
from typing import List, Optional, Union
from functools import reduce
from operator import mul
from typing import List, Optional, Union, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def prepare_scheduler_for_custom_training(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
return
Expand Down Expand Up @@ -42,7 +50,7 @@ def enforce_zero_terminal_snr(betas):
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2
alphas_bar = alphas_bar_sqrt ** 2
alphas = alphas_bar[1:] / alphas_bar[:-1]
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
Expand All @@ -65,7 +73,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
else:
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
Expand Down Expand Up @@ -93,13 +101,15 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
loss = loss + loss / scale * v_pred_like_loss
return loss


def apply_debiased_estimation(loss, timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
weight = 1/torch.sqrt(snr_t)
weight = 1 / torch.sqrt(snr_t)
loss = weight * loss
return loss


# TODO train_utilと分散しているのでどちらかに寄せる


Expand Down Expand Up @@ -293,7 +303,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
w += weights[i][j * (chunk_length - 2): min(len(weights[i]), (j + 1) * (chunk_length - 2))]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
Expand All @@ -320,7 +330,7 @@ def get_unweighted_text_embeddings(
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
text_input_chunk = text_input[:, i * (chunk_length - 2): (i + 1) * (chunk_length - 2) + 2].clone()

# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
Expand Down Expand Up @@ -451,8 +461,8 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations):
r = random.random() * 2 + 2 # Rather than always going 2x,
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
wn, hn = max(1, int(w / (r ** i))), max(1, int(h / (r ** i)))
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount ** i
if wn == 1 or hn == 1:
break # Lowest resolution is 1x1
return noise / noise.std() # Scaled back to roughly unit variance
Expand All @@ -474,6 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
return noise


def apply_masked_loss(loss, batch):
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
Expand All @@ -484,57 +495,42 @@ def apply_masked_loss(loss, batch):
loss = loss * mask_image
return loss, mask_image

## Custom loss function for weighing a character, and specifical details of the character, differently from the background
def apply_multichannel_masked_loss(loss, batch, weight1, weight2, weight3):
# Merge the character and detail masks to have character mask in channel 1 and detail mask in channel 2
mask1 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)
mask2 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 1].unsqueeze(1)

# # Logging for debugging if needed
# threshold = 0.5
# red_pixels = (mask1 > threshold).sum().item()
# green_pixels = (mask2 > threshold).sum().item()
#
# total_pixels = mask1.numel()
#
# logger.info(f"Red pixels: {red_pixels}, green pixels: {green_pixels}, total pixels: {total_pixels}")

# resize to the same size as the loss
mask1 = torch.nn.functional.interpolate(mask1, size=loss.shape[2:], mode="area")
mask2 = torch.nn.functional.interpolate(mask2, size=loss.shape[2:], mode="area")
## Custom loss function for weighing a character, and specifical details of the character, differently from the background
def apply_multichannel_masked_loss(loss, batch, *weights: float) -> Tuple[Tensor, Tensor]:
# masks are in channels 0 and 1 in the RGB conditioning image
# todo: pass this some other way, potentially using EXR format for multichannel image data
character_mask = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) / 2.0 + 0.5
detail_mask = batch["conditioning_images"].to(dtype=loss.dtype)[:, 1].unsqueeze(1) / 2.0 + 0.5

# Normalize masks to range 0 to 1
mask1 = mask1 / 2 + 0.5
mask2 = mask2 / 2 + 0.5
# resize to the same size in the height, width dimension as the loss
loss_height_width = loss.shape[2:]
character_mask_latents = (F.interpolate(character_mask, size=loss_height_width, mode="area") != 0.0).float()
detail_mask_latents: Tensor = (F.interpolate(detail_mask, size=loss_height_width, mode="area") != 0.0).float()

# invert mask 1, so that we can separate MrM and the background and calculate separately
mask1_inv = 1 - mask1

# Assuming mask channel is either 0 or 1, calculate total # of pixels "in" the mask
bowtie_pixels = mask2.sum()
background_mask_latent = 1 - character_mask_latents

background_weight = weight1
character_weight = weight2
detail_weight = weight3
# assuming mask channel is either 0 or 1, calculate total # of pixels "in" the mask
size_of_detail_tensor = detail_mask_latents.count_nonzero()

background_loss = background_weight * (loss * mask1_inv)
character_loss = character_weight * (loss * mask1)

if bowtie_pixels == 0:
detail_loss = torch.tensor(0.0, dtype=loss.dtype, device=loss.device)
else:
# compute weights
background_weight, character_weight, detail_weight = weights
detail_weight_with_size_correction = (torch.log(torch.tensor(reduce(mul, loss_height_width))).float() - torch.log(size_of_detail_tensor.float())) * detail_weight

detail_loss = (detail_weight * (torch.log(torch.tensor(loss.numel(), device=loss.device).float()) - torch.log(bowtie_pixels.float())) * (loss * mask2))
background_loss = background_weight * (loss * background_mask_latent)
character_loss = character_weight * (loss * character_mask_latents)

detail_loss = torch.maximum(torch.tensor(0.0, dtype=loss.dtype, device=loss.device), -1 + detail_loss)
nonzero_detail_loss = detail_weight_with_size_correction * (loss * detail_mask_latents)
detail_loss = torch.where(size_of_detail_tensor > 0, nonzero_detail_loss, torch.tensor(0.0))

# logger.info(f"Previous loss: {loss.sum()}")
# logger.info(f"Total background loss: {background_loss.sum()}")
# logger.info(f"Total character loss: {character_loss.sum()}")
# logger.info(f"Total detail loss: {detail_loss.sum()}")

final_loss = background_loss + character_loss + detail_loss
return final_loss, mask1_inv
return final_loss, background_mask_latent


"""
Expand Down

0 comments on commit a5655d4

Please sign in to comment.