Updated sd-scripts w/ custom loss for sdxl + moved config around to better location
Jeff Ding committed May 14, 2024
1 parent 322e640 commit a22c9e1
Showing 3 changed files with 96 additions and 57 deletions.
23 changes: 23 additions & 0 deletions sd_scripts/library/train_util.py
@@ -3335,6 +3335,29 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=1.0,
help="probability of masking loss (default is None) / masked lossの確率(デフォルトはNone)",
)
parser.add_argument(
"--background_weight",
type=float,
default=1.0,
help="Background weight for multi-mask training",
)
parser.add_argument(
"--character_weight",
type=float,
default=1.0,
help="Character weight for multi-mask training",
)
parser.add_argument(
"--detail_weight",
type=float,
default=1.0,
help="Detail weight for multi-mask training",
)
parser.add_argument(
"--weighted_loss",
action="store_true",
help="Use multi-mask weighted loss",
)

if support_dreambooth:
# DreamBooth training
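Note: the three weight flags added above feed the multi-mask ("weighted") loss path, but apply_multichannel_masked_loss itself is not part of this commit. The following is only a minimal sketch of how such a weighting could combine per-region masks, under the assumption that the batch carries a 3-channel mask (background, character, detail); the channel order and mask source are assumptions, not taken from the diff.

import torch

def multichannel_masked_loss(loss, masks, background_weight=1.0, character_weight=1.0, detail_weight=1.0):
    # loss:  (B, C, H, W) per-element loss from the trainer
    # masks: (B, 3, H, W) -- assumed channel order: background, character, detail
    weight_map = (
        masks[:, 0:1] * background_weight
        + masks[:, 1:2] * character_weight
        + masks[:, 2:3] * detail_weight
    )
    # scale the per-pixel loss by the per-region weights; the map is also returned so it can mask noise stats
    return loss * weight_map, weight_map

Note that the apply_multichannel_masked_loss calls shown in train_network.py below still pass hardcoded weights of (1.0, 1.0, 1.0) rather than the new flags.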
23 changes: 0 additions & 23 deletions sd_scripts/stable_cascade_train_c_network.py
@@ -1209,29 +1209,6 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--background_weight",
type=float,
default=1.0,
help="Background weight for multi-mask training",
)
parser.add_argument(
"--character_weight",
type=float,
default=1.0,
help="Character weight for multi-mask training",
)
parser.add_argument(
"--detail_weight",
type=float,
default=1.0,
help="Detail weight for multi-mask training",
)
parser.add_argument(
"--weighted_loss",
action="store_true",
help="Use multi-mask weighted loss",
)
return parser

def is_decreasing(list, min_delta=0.001):
107 changes: 73 additions & 34 deletions sd_scripts/train_network.py
@@ -20,6 +20,7 @@
init_ipex()

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F

from accelerate.utils import set_seed
from diffusers import DDPMScheduler
@@ -142,7 +143,7 @@ def all_reduce_network(self, accelerator, network):
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)

def process_batch(self, batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=True):
def process_batch(self, batch, is_train, train_unet, alphas_cumprod, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=True):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
@@ -215,38 +216,68 @@ def process_batch(self, batch, is_train, train_unet, network, network_has_multip
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
if args.masked_loss and np.random.rand() < args.masked_loss_prob:
loss, noise_mask = apply_multichannel_masked_loss(loss, batch, 1.0, 1.0, 1.0)
if args.use_sig_loss:
mae_loss = F.l1_loss(noise_pred, target, reduction="none")
mse_loss = F.mse_loss(noise_pred, target, reduction="none")
base_loss = 1/-mse_loss.exp() + 1

ac = alphas_cumprod[timesteps]
loss = base_loss.mean(dim=(2, 3), keepdims=True) * ac.sqrt()
loss = loss + base_loss.std(dim=(2,3), keepdims=True) * (1-ac).sqrt()

if args.masked_loss and np.random.rand() < args.masked_loss_prob:
loss = apply_masked_loss(loss, batch)

loss = loss.mean(dim=(1,2,3))

loss_weights = batch["loss_weights"] # weight for each sample
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)


loss = loss.mean() # already averaged, so no need to divide by batch_size
else:
noise_mask = torch.ones_like(noise, device=noise.device)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"].to(accelerator.device) # weight for each sample
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

pred_std, pred_skews, pred_kurtoses = train_util.noise_stats(noise_pred * noise_mask)
true_std, true_skews, true_kurtoses = train_util.noise_stats(noise * noise_mask)

if args.std_loss_weight is not None:
std_loss = torch.nn.functional.mse_loss(pred_std, true_std, reduction="none")
loss = loss + std_loss * args.std_loss_weight

step_logs["metrics/noise_pred_std"] = pred_std.mean().item()
step_logs["metrics/noise_pred_mean"] = noise_pred.mean()
step_logs["metrics/std_divergence"] = true_std.mean().item() - pred_std.mean().item()
step_logs["metrics/skew_divergence"] = true_skews.mean().item() - pred_skews.mean().item()
step_logs["metrics/kurtosis_divergence"] = true_kurtoses.mean().item() - pred_kurtoses.mean().item()
loss = loss.mean() # already averaged, so no need to divide by batch_size
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
if args.masked_loss and np.random.rand() < args.masked_loss_prob:
loss, noise_mask = apply_multichannel_masked_loss(loss, batch, 1.0, 1.0, 1.0)
noise_mask = torch.ones_like(noise, device=noise.device)
else:
noise_mask = torch.ones_like(noise, device=noise.device)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"].to(accelerator.device) # weight for each sample
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

pred_std, pred_skews, pred_kurtoses = train_util.noise_stats(noise_pred * noise_mask)
true_std, true_skews, true_kurtoses = train_util.noise_stats(noise * noise_mask)

if args.std_loss_weight is not None:
std_loss = torch.nn.functional.mse_loss(pred_std, true_std, reduction="none")
loss = loss + std_loss * args.std_loss_weight

step_logs["metrics/noise_pred_std"] = pred_std.mean().item()
step_logs["metrics/noise_pred_mean"] = noise_pred.mean()
step_logs["metrics/std_divergence"] = true_std.mean().item() - pred_std.mean().item()
step_logs["metrics/skew_divergence"] = true_skews.mean().item() - pred_skews.mean().item()
step_logs["metrics/kurtosis_divergence"] = true_kurtoses.mean().item() - pred_kurtoses.mean().item()
loss = loss.mean() # already averaged, so no need to divide by batch_size
return loss
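
Note: as read off the diff above, the --use_sig_loss branch squashes the per-element MSE through 1 - exp(-mse) and then mixes its spatial mean and spatial standard deviation by sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t). A self-contained sketch of that shaping follows; the (B, C, H, W) shapes and the explicit reshape of alphas_cumprod for broadcasting are my assumptions, not shown in the committed code.

import torch
import torch.nn.functional as F

def sig_loss(noise_pred, target, timesteps, alphas_cumprod):
    # Per-element squared error squashed into [0, 1): 1 - exp(-mse),
    # equivalent to the diff's `1/-mse_loss.exp() + 1`.
    mse = F.mse_loss(noise_pred.float(), target.float(), reduction="none")
    base = 1.0 - torch.exp(-mse)

    # alpha_bar for each sample's timestep, reshaped to broadcast over (B, C, H, W)
    ac = alphas_cumprod.to(noise_pred.device)[timesteps].view(-1, 1, 1, 1)

    # low-noise timesteps (alpha_bar near 1) weight the error level,
    # high-noise timesteps (alpha_bar near 0) weight its spatial spread
    loss = base.mean(dim=(2, 3), keepdim=True) * ac.sqrt()
    loss = loss + base.std(dim=(2, 3), keepdim=True) * (1.0 - ac).sqrt()
    return loss.mean(dim=(1, 2, 3))  # one value per sample, later multiplied by loss_weights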


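Note: train_util.noise_stats and the --std_loss_weight option are referenced above but not defined in this commit. As a rough sketch only, a per-sample implementation returning the std, skew, and kurtosis that the logging keys suggest might look like the following; this is an assumption about its shape, not the repository's actual function.

import torch

def noise_stats(x):
    # x: (B, C, H, W) noise or noise prediction (optionally pre-multiplied by a mask)
    flat = x.flatten(start_dim=1).float()
    mean = flat.mean(dim=1, keepdim=True)
    std = flat.std(dim=1, keepdim=True)
    z = (flat - mean) / (std + 1e-8)
    skew = (z ** 3).mean(dim=1)
    kurt = (z ** 4).mean(dim=1) - 3.0  # excess kurtosis
    return std.squeeze(1), skew, kurt

Per the diff, when --std_loss_weight is set, the MSE between predicted and true per-sample std is added to the loss, and the std/skew/kurtosis divergences are logged in step_logs.
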
@@ -419,10 +450,18 @@ def train(self, args):
# workaround for LyCORIS (;^ω^)
net_kwargs["dropout"] = args.network_dropout

if args.text_encoder_dim is None:
args.text_encoder_dim = args.network_dim

if args.text_encoder_alpha is None:
args.text_encoder_alpha = args.network_alpha

network = network_module.create_network(
1.0,
args.network_dim,
args.network_alpha,
args.text_encoder_dim,
args.text_encoder_alpha,
vae,
text_encoder,
unet,
@@ -878,14 +917,14 @@ def remove_model(old_ckpt_name):
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

# TRAINING

alphas_cumprod = noise_scheduler.alphas_cumprod.to(accelerator.device)
for step, batch in enumerate(train_dataloader):
step_logs = {}
current_step.value = global_step
with accelerator.accumulate(network):
on_step_start(text_encoder, unet)
is_train = True
loss = self.process_batch(batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=train_text_encoder)
loss = self.process_batch(batch, is_train, train_unet, alphas_cumprod, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=train_text_encoder)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
