
Commit

Added std loss and updates to masked loss
Jeff Ding committed Apr 25, 2024
1 parent bf6268b commit 3f10c45
Showing 4 changed files with 124 additions and 43 deletions.
2 changes: 1 addition & 1 deletion sd_scripts/library/custom_train_functions.py
@@ -481,7 +481,7 @@ def apply_masked_loss(loss, batch):
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
return loss
return loss, mask_image

"""
##########################################
56 changes: 51 additions & 5 deletions sd_scripts/library/train_util.py
@@ -3314,6 +3314,19 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)

parser.add_argument(
"--std_loss_weight",
type=float,
default=None,
help="Weight for the standard deviation loss term. Encourages the predicted noise to have a stddev matching the true noise, which may help prevent over-saturated ('deep fried') outputs. 1.0 is a good starting point.",
)
parser.add_argument(
"--masked_loss_prob",
type=float,
default=1.0,
help="probability of applying masked loss each step (default is 1.0) / masked lossを適用する確率(デフォルトは1.0)",
)

if support_dreambooth:
# DreamBooth training
parser.add_argument(
@@ -4877,13 +4890,13 @@ def get_my_scheduler(
elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = sample_sampler
if args.sample_sampler == 'k-sde-dpmsolver++' or args.sample_sampler == 'k-dpmsolver++':
sched_init_args["algorithm_type"] = args.sample_sampler[2:]
if sample_sampler == 'k-sde-dpmsolver++' or sample_sampler == 'k-dpmsolver++':
sched_init_args["algorithm_type"] = sample_sampler[2:]
sched_init_args["use_karras_sigmas"] = True
if args.sample_sampler == 'lu-sde-dpmsolver++':
sched_init_args["algorithm_type"] = args.sample_sampler[3:]
if sample_sampler == 'lu-sde-dpmsolver++':
sched_init_args["algorithm_type"] = sample_sampler[3:]
sched_init_args["use_lu_lambdas"] = True
if args.sample_sampler == 'k-sde-dpmsolver++' or args.sample_sampler == 'sde-dpmsolver++' or args.sample_sampler == 'lu-sde-dpmsolver++':
if sample_sampler == 'k-sde-dpmsolver++' or sample_sampler == 'sde-dpmsolver++' or sample_sampler == 'lu-sde-dpmsolver++':
sched_init_args["euler_at_final"] = True
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
elif sample_sampler == "dpmsingle":
@@ -5279,3 +5292,36 @@ def add(self, *, epoch: int, step: int, loss: float) -> None:
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)

def noise_stats(noise):
diff = noise - noise.mean(dim=(1,2,3), keepdim=True)
std = noise.std(dim=(1,2,3))
zscores = diff / std[:, None, None, None]
skews = (zscores**3).mean(dim=(1,2,3))
kurtoses = (zscores**4).mean(dim=(1,2,3)) - 3.0
return std, skews, kurtoses

def stat_losses(noise, noise_pred, std_loss_weight=0.5, kl_loss_weight=3e-3, skew_loss_weight=0, kurtosis_loss_weight=0):
std_loss = torch.nn.functional.mse_loss(
noise_pred.std(dim=(1,2,3)),
noise.std(dim=(1,2,3)),
reduction="none") * std_loss_weight

_, skew_pred, kurt_pred = noise_stats(noise_pred)  # noise_stats returns (std, skew, kurtosis); std is unused here
_, skew_true, kurt_true = noise_stats(noise)

skew_loss = torch.nn.functional.mse_loss(skew_pred, skew_true, reduction="none") * skew_loss_weight
kurt_loss = torch.nn.functional.mse_loss(kurt_pred, kurt_true, reduction="none") * kurtosis_loss_weight

p1s = []
p2s = []
for i, v in enumerate(noise_pred):
n = noise[i]
p1s.append(torch.histc(v.float(), bins=500, min=n.min(), max=n.max()) + 1e-6)
p2s.append(torch.histc(n.float(), bins=500) + 1e-6)
p1 = torch.stack(p1s)
p2 = torch.stack(p2s)

kl_loss = torch.nn.functional.kl_div(p1.log(), p2, reduction="none").mean(dim=1) * kl_loss_weight

return std_loss, skew_loss, kurt_loss, kl_loss
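
For orientation, here is a minimal sketch (not part of the commit) of how the new `noise_stats` helper and the `--std_loss_weight` term combine with the per-sample MSE loss, mirroring the training-loop changes below; the batch shape and the 1.0 weight are illustrative assumptions.

```python
import torch

from sd_scripts.library import train_util  # assumes the package layout above is importable

# Illustrative batch: 4 samples of predicted vs. true noise over (C=4, 32, 32) latents.
noise_pred = torch.randn(4, 4, 32, 32) * 0.8  # deliberately under-dispersed prediction
noise = torch.randn(4, 4, 32, 32)

# Per-sample stddev, skewness, and excess kurtosis over the non-batch dims.
pred_std, pred_skews, pred_kurtoses = train_util.noise_stats(noise_pred)
true_std, true_skews, true_kurtoses = train_util.noise_stats(noise)

# Per-sample MSE loss, reduced over everything except the batch dim, as in the training scripts.
loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none").mean(dim=(1, 2, 3))

# std loss: penalize the per-sample gap between predicted and true noise stddev.
std_loss_weight = 1.0  # illustrative value for --std_loss_weight
std_loss = torch.nn.functional.mse_loss(pred_std, true_std, reduction="none")
loss = (loss + std_loss * std_loss_weight).mean()
print(loss.item())
```

`stat_losses` bundles the same idea together with skewness, kurtosis, and histogram-KL terms, but it does not appear to be called from the training scripts in this diff.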
74 changes: 47 additions & 27 deletions sd_scripts/stable_cascade_train_c_network.py
@@ -17,7 +17,7 @@
from .library.device_utils import init_ipex, clean_memory_on_device
init_ipex()


import numpy as np
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

@@ -52,9 +52,9 @@ def __init__(self):

# TODO: unify this with the other scripts
def generate_step_logs(
self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None, extra={}
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
logs = {"loss/current": current_loss, "loss/average": avr_loss, **extra}

if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
@@ -844,6 +844,7 @@ def remove_model(old_ckpt_name):
accelerator.unwrap_model(network).on_epoch_start(text_encoder, stage_c)

for step, batch in enumerate(train_dataloader):
step_logs = {}
current_step.value = global_step
with accelerator.accumulate(network):
on_step_start(text_encoder, stage_c)
@@ -897,10 +898,28 @@ def remove_model(old_ckpt_name):
noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
)
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
if args.masked_loss:
loss = apply_masked_loss(loss, batch)

if args.masked_loss and np.random.rand() < args.masked_loss_prob:
loss, noise_mask = apply_masked_loss(loss, batch)
else:
noise_mask = torch.ones_like(noise, device=noise.device)

loss = loss.mean(dim=[1, 2, 3])
loss_adjusted = (loss * loss_weight).mean()
loss_adjusted = (loss * loss_weight)

pred_std, pred_skews, pred_kurtoses = train_util.noise_stats(pred * noise_mask)
true_std, true_skews, true_kurtoses = train_util.noise_stats(noise * noise_mask)

if args.std_loss_weight is not None:
std_loss = torch.nn.functional.mse_loss(pred_std, true_std, reduction="none")
loss = loss + std_loss * args.std_loss_weight

step_logs["metrics/noise_pred_std"] = pred_std.mean().item()
step_logs["metrics/noise_pred_mean"] = pred.mean().item()
step_logs["metrics/std_divergence"] = true_std.mean().item() - pred_std.mean().item()
step_logs["metrics/skew_divergence"] = true_skews.mean().item() - pred_skews.mean().item()
step_logs["metrics/kurtosis_divergence"] = true_kurtoses.mean().item() - pred_kurtoses.mean().item()
loss_adjusted = loss_adjusted.mean()

if args.adaptive_loss_weight:
gdf.loss_weight.update_buckets(logSNR, loss) # use loss instead of loss_adjusted
@@ -956,33 +975,34 @@ def remove_model(old_ckpt_name):
progress_bar.set_postfix(**{**max_mean_logs, **logs})

if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm, step_logs)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
break

if len(val_dataloader) > 0:
print("Validating バリデーション処理...")

with torch.no_grad():
for val_step, batch in enumerate(val_dataloader):
is_train = False
loss = self.process_batch(batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)

current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)

if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_current": current_loss}
accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)

if len(val_dataloader) > 0:
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_average": avr_loss}
accelerator.log(logs, step=epoch + 1)
print("Currently does not support validation set")
sys.exit("Error: Please set validation set to 0")

# with torch.no_grad():
# for val_step, batch in enumerate(val_dataloader):
# is_train = False
# loss = self.process_batch(batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
#
# current_loss = loss.detach().item()
# val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
#
# if args.logging_dir is not None:
# avr_loss: float = val_loss_recorder.moving_average
# logs = {"loss/validation_current": current_loss}
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
#
# if len(val_dataloader) > 0:
# if args.logging_dir is not None:
# avr_loss: float = val_loss_recorder.moving_average
# logs = {"loss/validation_average": avr_loss}
# accelerator.log(logs, step=epoch + 1)

if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
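Both training scripts thread the new per-step metrics into logging through the `extra` parameter added to `generate_step_logs`. A minimal standalone sketch of that merge, with hypothetical metric values:

```python
# Hypothetical per-step metrics collected during a training step.
step_logs = {
    "metrics/noise_pred_std": 0.97,
    "metrics/std_divergence": 0.03,
}

# Simplified stand-in for NetworkTrainer.generate_step_logs: the extra dict is
# merged into the base loss entries before being passed to accelerator.log(...).
def generate_step_logs(current_loss, avr_loss, extra={}):
    logs = {"loss/current": current_loss, "loss/average": avr_loss, **extra}
    return logs

print(generate_step_logs(0.12, 0.15, step_logs))
# {'loss/current': 0.12, 'loss/average': 0.15, 'metrics/noise_pred_std': 0.97, 'metrics/std_divergence': 0.03}
```

The mutable `extra={}` default is harmless here because the dict is only read, never modified.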
35 changes: 25 additions & 10 deletions sd_scripts/train_network.py
@@ -24,7 +24,7 @@
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from .library import model_util

import numpy as np
from .library import train_util
from .library.train_util import (
DreamBoothDataset,
@@ -60,9 +60,9 @@ def __init__(self):

# TODO: unify this with the other scripts
def generate_step_logs(
self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None, extra={}
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
logs = {"loss/current": current_loss, "loss/average": avr_loss, **extra}

if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
@@ -141,7 +141,7 @@ def all_reduce_network(self, accelerator, network):
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)

def process_batch(self, batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
def process_batch(self, batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=True):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
@@ -215,8 +215,10 @@ def process_batch(self, batch, is_train, train_unet, network, network_has_multip
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
if args.masked_loss and np.random.rand() < args.masked_loss_prob:
loss, noise_mask = apply_masked_loss(loss, batch)
else:
noise_mask = torch.ones_like(noise, device=noise.device)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"].to(accelerator.device) # per-sample weights
@@ -231,8 +233,19 @@ def process_batch(self, batch, is_train, train_unet, network, network_has_multip
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

loss = loss.mean() # already an average, so no need to divide by batch_size
pred_std, pred_skews, pred_kurtoses = train_util.noise_stats(noise_pred * noise_mask)
true_std, true_skews, true_kurtoses = train_util.noise_stats(noise * noise_mask)

if args.std_loss_weight is not None:
std_loss = torch.nn.functional.mse_loss(pred_std, true_std, reduction="none")
loss = loss + std_loss * args.std_loss_weight

step_logs["metrics/noise_pred_std"] = pred_std.mean().item()
step_logs["metrics/noise_pred_mean"] = noise_pred.mean().item()
step_logs["metrics/std_divergence"] = true_std.mean().item() - pred_std.mean().item()
step_logs["metrics/skew_divergence"] = true_skews.mean().item() - pred_skews.mean().item()
step_logs["metrics/kurtosis_divergence"] = true_kurtoses.mean().item() - pred_kurtoses.mean().item()
loss = loss.mean() # already an average, so no need to divide by batch_size
return loss


@@ -866,11 +879,12 @@ def remove_model(old_ckpt_name):
# TRAINING

for step, batch in enumerate(train_dataloader):
step_logs = {}
current_step.value = global_step
with accelerator.accumulate(network):
on_step_start(text_encoder, unet)
is_train = True
loss = self.process_batch(batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder)
loss = self.process_batch(batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=train_text_encoder)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -927,7 +941,7 @@ def remove_model(old_ckpt_name):
progress_bar.set_postfix(**{**max_mean_logs, **logs})

if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm, step_logs)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
@@ -940,8 +954,9 @@

with torch.no_grad():
for val_step, batch in enumerate(val_dataloader):
step_logs = {}
is_train = False
loss = self.process_batch(batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
loss = self.process_batch(batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs)

current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
