Skip to content

Commit

Permalink
WIP multichannel masked loss
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Ding committed May 2, 2024
1 parent 52e4c8b commit 489f84a
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 54 deletions.
21 changes: 12 additions & 9 deletions sd_scripts/library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,15 +491,13 @@ def apply_multichannel_masked_loss(loss, batch, weight1, weight2, weight3):
mask2 = batch["conditioning_images"].to(dtype=loss.dtype)[:, 1].unsqueeze(1)

# Logging for debugging if needed
# threshold = 0.5
# red_pixels = (mask1 > threshold).sum().item()
# green_pixels = (mask2 > threshold).sum().item()
#
# total_red_pixels = mask1.numel()
# total_green_pixels = mask2.numel()
#
# logger.info(f"Number of 'red' pixels in mask1 above threshold: {red_pixels}, total pixels: {total_red_pixels}")
# logger.info(f"Number of 'green' pixels in mask2 above threshold: {green_pixels}, total pixels: {total_green_pixels}")
threshold = 0.5
red_pixels = (mask1 > threshold).sum().item()
green_pixels = (mask2 > threshold).sum().item()

total_pixels = mask1.numel()

logger.info(f"Red pixels: {red_pixels}, green pixels: {green_pixels}, total pixels: {total_pixels}")

# resize to the same size as the loss
mask1 = torch.nn.functional.interpolate(mask1, size=loss.shape[2:], mode="area")
Expand All @@ -525,6 +523,11 @@ def apply_multichannel_masked_loss(loss, batch, weight1, weight2, weight3):

detail_loss = torch.maximum(torch.tensor(0.0, dtype=loss.dtype, device=loss.device), -1 + detail_loss)

logger.info(f"Previous loss: {loss.sum()}")
logger.info(f"Total background loss: {background_loss.sum()}")
logger.info(f"Total character loss: {character_loss.sum()}")
logger.info(f"Total detail loss: {detail_loss.sum()}")

final_loss = background_loss + character_loss + detail_loss

return final_loss, mask1_inv
Expand Down
75 changes: 49 additions & 26 deletions sd_scripts/library/stable_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import torch.utils.checkpoint
import torchvision


MODEL_VERSION_STABLE_CASCADE = "stable_cascade"

EFFNET_PREPROCESS = torchvision.transforms.Compose(
Expand All @@ -30,8 +29,8 @@ class vector_quantize(torch.autograd.Function):
@staticmethod
def forward(ctx, x, codebook):
with torch.no_grad():
codebook_sqr = torch.sum(codebook**2, dim=1)
x_sqr = torch.sum(x**2, dim=1, keepdim=True)
codebook_sqr = torch.sum(codebook ** 2, dim=1)
x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)

dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
_, indices = dist.min(dim=1)
Expand Down Expand Up @@ -172,7 +171,6 @@ def encode(self, x):
from torch.nn import Conv2d
from torch.nn import Linear


r"""
class Attention2D(nn.Module):
def __init__(self, c, nhead, dropout=0.0):
Expand Down Expand Up @@ -458,7 +456,8 @@ def __init__(self, c_in, c_out, mode, enabled=True):
super().__init__()
assert mode in ["up", "down"]
interpolation = (
nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True) if enabled else nn.Identity()
nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear",
align_corners=True) if enabled else nn.Identity()
)
mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation])
Expand Down Expand Up @@ -540,11 +539,12 @@ def forward(self, x):


class StageA(nn.Module):
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.43): # 0.3764
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
scale_factor=0.43): # 0.3764
super().__init__()
self.c_latent = c_latent
self.scale_factor = scale_factor
c_levels = [c_hidden // (2**i) for i in reversed(range(levels))]
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

# Encoder blocks
self.in_block = nn.Sequential(nn.PixelUnshuffle(2), nn.Conv2d(3 * 4, c_levels[0], kernel_size=1))
Expand Down Expand Up @@ -574,7 +574,8 @@ def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, cod
up_blocks.append(block)
if i < levels - 1:
up_blocks.append(
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1)
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
padding=1)
)
self.up_blocks = nn.Sequential(*up_blocks)
self.out_block = nn.Sequential(
Expand Down Expand Up @@ -672,7 +673,7 @@ def __init__(

self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
)

Expand Down Expand Up @@ -733,7 +734,8 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
for j in range(blocks[1][::-1][i]):
for k, block_type in enumerate(level_config[i]):
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i])
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
self_attn=self_attn[i])
up_block.append(block)
self.up_blocks.append(up_block)
if block_repeat is not None:
Expand All @@ -745,7 +747,7 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
# OUTPUT
self.clf = nn.Sequential(
LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1),
nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
nn.PixelShuffle(patch_size),
)

Expand Down Expand Up @@ -836,7 +838,8 @@ def _up_decode(self, level_outputs, r_embed, clip):
):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True)
x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear",
align_corners=True)
x = block(x, skip)
elif isinstance(block, AttnBlock) or (
hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock)
Expand Down Expand Up @@ -937,7 +940,7 @@ def __init__(

self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
)

Expand Down Expand Up @@ -998,7 +1001,8 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
for j in range(blocks[1][::-1][i]):
for k, block_type in enumerate(level_config[i]):
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i])
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
self_attn=self_attn[i])
up_block.append(block)
self.up_blocks.append(up_block)
if block_repeat is not None:
Expand All @@ -1010,7 +1014,7 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
# OUTPUT
self.clf = nn.Sequential(
LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1),
nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
nn.PixelShuffle(patch_size),
)

Expand Down Expand Up @@ -1088,7 +1092,8 @@ def _down_encode(self, x, r_embed, clip, cnet=None):
if cnet is not None:
next_cnet = cnet()
if next_cnet is not None:
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear",
align_corners=True)
x = block(x)
elif isinstance(block, AttnBlock) or (
hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock)
Expand Down Expand Up @@ -1116,11 +1121,13 @@ def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True)
x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear",
align_corners=True)
if cnet is not None:
next_cnet = cnet()
if next_cnet is not None:
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear",
align_corners=True)
x = block(x, skip)
elif isinstance(block, AttnBlock) or (
hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock)
Expand Down Expand Up @@ -1264,9 +1271,9 @@ def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0):
if len(a_prev.shape) == 1:
a_prev, b_prev = a_prev.view(-1, *[1] * (len(x0.shape) - 1)), b_prev.view(-1, *[1] * (len(x0.shape) - 1))

sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0
sigma_tau = eta * (b_prev ** 2 / b ** 2).sqrt() * (1 - a ** 2 / a_prev ** 2).sqrt() if eta > 0 else 0
# x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
x = a_prev * x0 + (b_prev ** 2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
return x


Expand All @@ -1284,16 +1291,30 @@ def step(self, x, x0, epsilon, logSNR, logSNR_prev):


class GDF:
def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0):
def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, num_timesteps=1000, offset_noise=0):
self.schedule = schedule
self.input_scaler = input_scaler
self.target = target
self.noise_cond = noise_cond
self.loss_weight = loss_weight
self.num_timesteps = num_timesteps
self.offset_noise = offset_noise

# Simulate betas and compute alphas and alphas_cumprod
timesteps = torch.linspace(0, 1, num_timesteps)
logSNR = self.schedule(timesteps)
betas = 1 - torch.exp(-logSNR) # This assumes logSNR is log(1/alpha)
self.alphas = 1 - betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

def sample_timesteps(self, batch_size, min_timestep=0, max_timestep=None):
if max_timestep is None:
max_timestep = self.num_timesteps - 1
return torch.randint(min_timestep, max_timestep + 1, (batch_size,)).long()

def setup_limits(self, stretch_max=True, stretch_min=True, shift=1):
stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift)
stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min,
shift)
return stretched_limits

def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
Expand All @@ -1310,7 +1331,8 @@ def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
target = self.target(x0, epsilon, logSNR, a, b)

# noised, noise, logSNR, t_cond
return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift)
return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR,
shift=loss_shift)

def undiffuse(self, x, logSNR, pred):
a, b = self.input_scaler(logSNR)
Expand Down Expand Up @@ -1343,7 +1365,8 @@ def sample(
sampler = DDPMSampler(self)
r_range = torch.linspace(t_start, t_end, timesteps + 1)
schedule = self.schedule if schedule is None else schedule
logSNR_range = schedule(r_range, shift=shift)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(device)
logSNR_range = schedule(r_range, shift=shift)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(
0)).to(device)

x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()
if cfg is not None:
Expand Down Expand Up @@ -1383,7 +1406,8 @@ def sample(
if isinstance(cfg_val, (list, tuple)):
assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1 - r_range[i].item())
pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2)
pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(
2)
pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
if cfg_rho > 0:
std_pos, std_cfg = pred.std(), pred_cfg.std()
Expand Down Expand Up @@ -1650,5 +1674,4 @@ def update_buckets(self, logSNR, loss, beta=0.99):
indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta)


# endregion gdf
5 changes: 5 additions & 0 deletions sd_scripts/library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3324,6 +3324,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="Weight for standard deviation loss. Encourages the model to learn noise with a stddev like the true noise. May prevent 'deep fry'. 1.0 is a good starting place.",
)
# NOTE(review): this help text is copied verbatim from the masked-loss flag and
# describes masking, not sig loss — update it to explain what --use_sig_loss
# actually enables before merging.
parser.add_argument(
"--use_sig_loss",
action="store_true",
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
)
parser.add_argument(
"--masked_loss_prob",
type=float,
Expand Down
Loading

0 comments on commit 489f84a

Please sign in to comment.