diff --git a/sd_scripts/library/train_util.py b/sd_scripts/library/train_util.py index cb7456f2c..6e92976bf 100644 --- a/sd_scripts/library/train_util.py +++ b/sd_scripts/library/train_util.py @@ -3335,6 +3335,29 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=1.0, help="probability of masking loss (default is None) / masked lossの確率(デフォルトはNone)", ) + parser.add_argument( + "--background_weight", + type=float, + default=1.0, + help="Background weight for multi-mask training", + ) + parser.add_argument( + "--character_weight", + type=float, + default=1.0, + help="Character weight for multi-mask training", + ) + parser.add_argument( + "--detail_weight", + type=float, + default=1.0, + help="Detail weight for multi-mask training", + ) + parser.add_argument( + "--weighted_loss", + action="store_true", + help="Use multi-mask weighted loss", + ) if support_dreambooth: # DreamBooth training diff --git a/sd_scripts/stable_cascade_train_c_network.py b/sd_scripts/stable_cascade_train_c_network.py index 4e0a620d3..19edcbbb0 100644 --- a/sd_scripts/stable_cascade_train_c_network.py +++ b/sd_scripts/stable_cascade_train_c_network.py @@ -1209,29 +1209,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - parser.add_argument( - "--background_weight", - type=float, - default=1.0, - help="Background weight for multi-mask training", - ) - parser.add_argument( - "--character_weight", - type=float, - default=1.0, - help="Character weight for multi-mask training", - ) - parser.add_argument( - "--detail_weight", - type=float, - default=1.0, - help="Detail weight for multi-mask training", - ) - parser.add_argument( - "--weighted_loss", - action="store_true", - help="Use multi-mask weighted loss", - ) return parser def is_decreasing(list, min_delta=0.001): diff --git a/sd_scripts/train_network.py b/sd_scripts/train_network.py index 4433e3a19..2bd83a9fa 100644 --- a/sd_scripts/train_network.py +++ b/sd_scripts/train_network.py @@ -20,6 +20,7 @@ init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP +import torch.nn.functional as F from accelerate.utils import set_seed from diffusers import DDPMScheduler @@ -142,7 +143,7 @@ def all_reduce_network(self, accelerator, network): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - def process_batch(self, batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=True): + def process_batch(self, batch, is_train, train_unet, alphas_cumprod, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=True): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -215,38 +216,68 @@ def process_batch(self, batch, is_train, train_unet, network, network_has_multip else: target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss and np.random.rand() < args.masked_loss_prob: - loss, noise_mask = apply_multichannel_masked_loss(loss, batch, 1.0, 1.0, 1.0) + if args.use_sig_loss: + mae_loss = F.l1_loss(noise_pred, target, reduction="none") + mse_loss = F.mse_loss(noise_pred, target, reduction="none") + base_loss = 1/-mse_loss.exp() + 1 + + ac = alphas_cumprod[timesteps] + loss = base_loss.mean(dim=(2, 3), keepdims=True) * ac.sqrt() + loss = loss + base_loss.std(dim=(2,3), keepdims=True) * (1-ac).sqrt() + + if args.masked_loss and np.random.rand() < args.masked_loss_prob: + loss = apply_masked_loss(loss, batch) + + loss = loss.mean(dim=(1,2,3)) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし else: - noise_mask = torch.ones_like(noise, device=noise.device) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - pred_std, pred_skews, pred_kurtoses = train_util.noise_stats(noise_pred * noise_mask) - true_std, true_skews, true_kurtoses = train_util.noise_stats(noise * noise_mask) - - if args.std_loss_weight is not None: - std_loss = torch.nn.functional.mse_loss(pred_std, true_std, reduction="none") - loss = loss + std_loss * args.std_loss_weight - - step_logs["metrics/noise_pred_std"] = pred_std.mean().item() - step_logs["metrics/noise_pred_mean"] = noise_pred.mean() - step_logs["metrics/std_divergence"] = true_std.mean().item() - pred_std.mean().item() - step_logs["metrics/skew_divergence"] = true_skews.mean().item() - pred_skews.mean().item() - step_logs["metrics/kurtosis_divergence"] = true_kurtoses.mean().item() - pred_kurtoses.mean().item() - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss and np.random.rand() < args.masked_loss_prob: + loss, noise_mask = apply_multichannel_masked_loss(loss, batch, 1.0, 1.0, 1.0) + noise_mask = torch.ones_like(noise, device=noise.device) + else: + noise_mask = torch.ones_like(noise, device=noise.device) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + pred_std, pred_skews, pred_kurtoses = train_util.noise_stats(noise_pred * noise_mask) + true_std, true_skews, true_kurtoses = train_util.noise_stats(noise * noise_mask) + + if args.std_loss_weight is not None: + std_loss = torch.nn.functional.mse_loss(pred_std, true_std, reduction="none") + loss = loss + std_loss * args.std_loss_weight + + step_logs["metrics/noise_pred_std"] = pred_std.mean().item() + step_logs["metrics/noise_pred_mean"] = noise_pred.mean() + step_logs["metrics/std_divergence"] = true_std.mean().item() - pred_std.mean().item() + step_logs["metrics/skew_divergence"] = true_skews.mean().item() - pred_skews.mean().item() + step_logs["metrics/kurtosis_divergence"] = true_kurtoses.mean().item() - pred_kurtoses.mean().item() + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし return loss @@ -419,10 +450,18 @@ def train(self, args): # workaround for LyCORIS (;^ω^) net_kwargs["dropout"] = args.network_dropout + if args.text_encoder_dim is None: + args.text_encoder_dim = args.network_dim + + if args.text_encoder_alpha is None: + args.text_encoder_alpha = args.network_alpha + network = network_module.create_network( 1.0, args.network_dim, args.network_alpha, + args.text_encoder_dim, + args.text_encoder_alpha, vae, text_encoder, unet, @@ -878,14 +917,14 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # TRAINING - + alphas_cumprod = noise_scheduler.alphas_cumprod.to(accelerator.device) for step, batch in enumerate(train_dataloader): step_logs = {} current_step.value = global_step with accelerator.accumulate(network): on_step_start(text_encoder, unet) is_train = True - loss = self.process_batch(batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=train_text_encoder) + loss = self.process_batch(batch, is_train, train_unet, alphas_cumprod, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=train_text_encoder) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: