diff --git a/sd_scripts/library/autostats.py b/sd_scripts/library/autostats.py new file mode 100644 index 000000000..bc81f5816 --- /dev/null +++ b/sd_scripts/library/autostats.py @@ -0,0 +1,110 @@ +import torch +import torch.nn.functional as F +import os +from safetensors import safe_open +import numpy as np + +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +standard_normal_distribution = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0])) +def smooth(probs, step_size=0.5): + kernel = standard_normal_distribution.log_prob(torch.arange(-torch.pi, torch.pi, step_size) ).exp().to(probs.device) + smoothed = F.conv1d(probs[None, None, :].float(), kernel[None, None, :].float(), padding="same").reshape(-1) + return smoothed / smoothed.sum() + +def kerras_timesteps(n, sigma_min=0.001, sigma_max=10.0): + alpha_min = torch.arctan(torch.tensor(sigma_min)) + alpha_max = torch.arctan(torch.tensor(sigma_max)) + step_indices = torch.arange(n) + sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max) + return sigmas + +# cribbed from A111 +def read_metadata_from_safetensors(filename): + import json + + with open(filename, mode="rb") as file: + metadata_len = file.read(8) + metadata_len = int.from_bytes(metadata_len, "little") + json_start = file.read(2) + + assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file" + + res = {} + try: + json_data = json_start + file.read(metadata_len-2) + json_obj = json.loads(json_data) + for k, v in json_obj.get("__metadata__", {}).items(): + res[k] = v + if isinstance(v, str) and v[0:1] == '{': + try: + res[k] = json.loads(v) + except Exception: + pass + except Exception: + logger.error(f"Error reading metadata from file: {filename}", exc_info=True) + + return res + +def interp_forward(t, timesteps): + p = t.permute(1, 0).float().cpu().numpy() # Switch to channel-first and flip the order from first-denoised to first-noised + rev_ts = torch.tensor(timesteps).tolist() # Reverse the timesteps from denoising order to noising order + xs = np.arange(0, 1000) + t = torch.stack([torch.tensor(list(np.interp(xs, rev_ts, p[i]))) for i in range(0, 4)]) + return t.permute(1, 0).to(t.device) + +def load_model_noise_stats(args): + if args.autostats is None or not os.path.exists(args.autostats): + return None, None + with safe_open(args.autostats, framework="pt") as f: + observations = f.get_tensor("observations") + timesteps = f.get_tensor("timesteps") + return transform_observations(observations, timesteps) + +def transform_observations(observations, timesteps): + # shape is [timestep, sample, channels, h, w] + # we average on sample, h, w so that we get stats for [timestep, channel] + + means = observations.mean(dim=(1, 3, 4)) + stds = observations.std(dim=(1, 3, 4)) + return interp_forward(means, timesteps), interp_forward(stds, timesteps) + +def autostats(args, generator): + timestep_probs = torch.ones(1000) + std_target_by_ts = mean_target_by_ts = scaled_std_target_by_ts = scaled_mean_target_by_ts = None + + mean_target_by_ts, std_target_by_ts = load_model_noise_stats(args) + if mean_target_by_ts is None: + generator() + mean_target_by_ts, std_target_by_ts = load_model_noise_stats(args) + + if mean_target_by_ts is None: + raise ValueError("Could not load noise stats from model") + + std_target_by_ts = std_target_by_ts.view(-1, 4, 1, 1) + mean_target_by_ts = mean_target_by_ts.view(-1, 4, 1, 1) + + std_weighting = (std_target_by_ts - 1).abs() + std_weighting = std_weighting / std_weighting.max(dim=0).values + + mean_weighting = mean_target_by_ts.abs() + mean_weighting = mean_weighting / mean_weighting.max(dim=0).values + + effect_scale = args.autostats_true_noise_weight + scaled_std_target_by_ts = (std_target_by_ts - 1.0) * effect_scale[0] + 1.0 + scaled_mean_target_by_ts = (mean_target_by_ts * effect_scale[1]) + + if args.autostats_timestep_weighting: + timestep_probs = (std_target_by_ts - 1).abs().mean(dim=1).reshape(-1) + timestep_probs[:15] = timestep_probs[15] + timestep_probs = smooth(timestep_probs) + + timestep_probs = timestep_probs / timestep_probs.sum() + + print("std", scaled_std_target_by_ts.view(-1, 4)) + print("mean", scaled_mean_target_by_ts.view(-1, 4)) + + return std_target_by_ts, mean_target_by_ts, scaled_std_target_by_ts, scaled_mean_target_by_ts, timestep_probs diff --git a/sd_scripts/library/custom_train_functions.py b/sd_scripts/library/custom_train_functions.py index 6d04a6b06..ac8adbea7 100644 --- a/sd_scripts/library/custom_train_functions.py +++ b/sd_scripts/library/custom_train_functions.py @@ -485,15 +485,20 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, batch): +def get_mask(batch, latents): # mask image is -1 to 1. we need to convert it to 0 to 1 - mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = batch["conditioning_images"].to(dtype=latents.dtype)[:, 0].unsqueeze(1) # use R channel # resize to the same size as the loss - mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") + mask_image = torch.nn.functional.interpolate(mask_image, size=latents.shape[2:], mode="area") mask_image = mask_image / 2 + 0.5 + return mask_image + + +def apply_masked_loss(loss, batch): + mask_image = get_mask(batch, loss) loss = loss * mask_image - return loss, mask_image + return loss ## Custom loss function for weighing a character, and specifical details of the character, differently from the background diff --git a/sd_scripts/library/train_util.py b/sd_scripts/library/train_util.py index 41f02d9b6..f7a393506 100644 --- a/sd_scripts/library/train_util.py +++ b/sd_scripts/library/train_util.py @@ -724,7 +724,7 @@ def process_caption(self, subset: BaseSubset, caption): tokens_len = ( math.floor( (self.current_step) * ( - (len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)) + (len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)) ) + subset.token_warmup_min ) @@ -3076,7 +3076,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( "--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする" + help="enable gradient checkpointing / gradient checkpointingを有効にする" ) parser.add_argument( "--gradient_accumulation_steps", @@ -3219,7 +3219,13 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))", ) - + parser.add_argument( + "--loss_type", + type=str, + default="l2", + choices=["l2", "huber", "smooth_l1"], + help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2", + ) parser.add_argument( "--lowram", action="store_true", @@ -3330,6 +3336,66 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="Weight for standard deviation loss. Encourages the model to learn noise with a stddev like the true noise. May prevent 'deep fry'. 1.0 is a good starting place.", ) + parser.add_argument( + "--autostats", + type=str, + default=None, + help="If set, generate and use autostats" + ) + parser.add_argument( + "--use_sts_loss", + action="store_true", + help="Use sts loss", + ) + parser.add_argument( + "--autostats_prompts", + type=str, + default=None, + nargs="*", + help='Prompts to use for autostats collection', + ) + parser.add_argument( + "--autostats_batch_size", + type=int, + default=1, + help='Number of prompts to process at a time', + ) + parser.add_argument( + "--autostats_true_noise_weight", + type=float, + nargs=2, + default=[1.0, 1.0], + help='Effect size; larger means more detail. Arg 0 is standard deviation, arg 1 is mean.', + ) + parser.add_argument( + "--autostats_loss_weights", + type=float, + nargs=2, + default=[1.0, 1.0], + help='Loss weights for std/mean', + ) + parser.add_argument( + "--autostats_decay_rate", + type=float, + default=0.0, + help='Decay rate for autostats loss weight. 0 = no decay.', + ) + parser.add_argument( + "--autostats_effect_min", + type=float, + default=0.0, + help='Minimum effect size for autostats', + ) + parser.add_argument( + "--autostats_timestep_weighting", + action="store_true", + help='Weight timestep selection probability by autostats', + ) + parser.add_argument( + "--autostats_dynamic_timestep_weighting", + action="store_true", + help='When using timestep weighting, dynamically adjust the weighting based on observed losses', + ) parser.add_argument( "--use_sig_loss", action="store_true", @@ -3371,6 +3437,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) + def add_masked_loss_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--conditioning_data_dir", @@ -3384,6 +3451,7 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", ) + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable @@ -4186,7 +4254,7 @@ def load_tokenizer(args: argparse.Namespace): return tokenizer -def prepare_accelerator(args: argparse.Namespace): +def prepare_accelerator(args: argparse.Namespace) -> Accelerator: if args.logging_dir is None: logging_dir = None else: @@ -4847,33 +4915,108 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): +def get_timesteps(args, noise_scheduler, probs, b_size, device): + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep + + if args.autostats_timestep_weighting: + probs = probs[min_timestep:max_timestep].float() + probs = probs / probs.sum() + cat = torch.distributions.Categorical(probs=probs) + timesteps = cat.sample([b_size]).to(device=device) + min_timestep + else: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + return timesteps.long() + + +def get_timesteps_and_huber_c(args, timesteps, noise_scheduler): + if args.loss_type == "huber" or args.loss_type == "smooth_l1": + timesteps = timesteps[0].repeat(timesteps.size(0)) + timestep = timesteps.item() + + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + huber_c = math.exp(-alpha * timestep) + elif args.huber_schedule == "snr": + alphas_cumprod = noise_scheduler.alphas_cumprod[timestep] + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + elif args.huber_schedule == "constant": + huber_c = args.huber_c + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + elif args.loss_type == "l2": + huber_c = 1 # may be anything, as it's not used + else: + raise NotImplementedError(f"Unknown loss type {args.loss_type}") + + return timesteps, huber_c + + +def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents, std_by_ts, mean_by_ts, timestep_probs): + # Sample a random timestep for each image + b_size = latents.shape[0] + + timesteps = get_timesteps(args, noise_scheduler, timestep_probs, b_size, latents.device) + timesteps, huber_c = get_timesteps_and_huber_c(args, timesteps, noise_scheduler) + # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) + if std_by_ts is not None: + channels = [] + for t in timesteps: + for i in range(0, 4): + # mean = 0 + std = std_by_ts[t][i][0][0] + mean = mean_by_ts[t][i][0][0] + channels.append(torch.empty((1, 1, latents.shape[2], latents.shape[3]), device=latents.device).normal_(mean=mean, std=std)) + noise = torch.cat(channels, dim=0).reshape(latents.shape) + else: + noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) if args.multires_noise_iterations: noise = custom_train_functions.pyramid_noise_like( noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount ) - # Sample a random timestep for each image - b_size = latents.shape[0] - min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device) - timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: - noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), - timesteps) + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps + return noise, noisy_latents, timesteps, huber_c + + +def conditional_loss( + model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 +): + if loss_type == "l2": + loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) + elif loss_type == "huber": + loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c ** 2) - huber_c) + if reduction == "mean": + loss = torch.mean(loss) + elif reduction == "sum": + loss = torch.sum(loss) + elif loss_type == "smooth_l1": + loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c ** 2) - huber_c) + if reduction == "mean": + loss = torch.mean(loss) + elif reduction == "sum": + loss = torch.sum(loss) + else: + raise NotImplementedError(f"Unsupported Loss Type {loss_type}") + return loss def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): @@ -5331,18 +5474,20 @@ def add(self, *, epoch: int, step: int, loss: float) -> None: def moving_average(self) -> float: return self.loss_total / len(self.loss_list) + def noise_stats(noise): - diff = noise - noise.mean(dim=(1,2,3), keepdim=True) - std = noise.std(dim=(1,2,3)) + diff = noise - noise.mean(dim=(1, 2, 3), keepdim=True) + std = noise.std(dim=(1, 2, 3)) zscores = diff / std[:, None, None, None] - skews = (zscores**3).mean(dim=(1,2,3)) - kurtoses = (zscores**4).mean(dim=(1,2,3)) - 3.0 + skews = (zscores ** 3).mean(dim=(1, 2, 3)) + kurtoses = (zscores ** 4).mean(dim=(1, 2, 3)) - 3.0 return std, skews, kurtoses + def stat_losses(noise, noise_pred, std_loss_weight=0.5, kl_loss_weight=3e-3, skew_loss_weight=0, kurtosis_loss_weight=0): std_loss = torch.nn.functional.mse_loss( - noise_pred.std(dim=(1,2,3)), - noise.std(dim=(1,2,3)), + noise_pred.std(dim=(1, 2, 3)), + noise.std(dim=(1, 2, 3)), reduction="none") * std_loss_weight skew_pred, kurt_pred = noise_stats(noise_pred) diff --git a/sd_scripts/sdxl_train.py b/sd_scripts/sdxl_train.py index aa0d28dd1..85cc5487a 100755 --- a/sd_scripts/sdxl_train.py +++ b/sd_scripts/sdxl_train.py @@ -564,7 +564,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -582,7 +584,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss: loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -598,7 +602,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: diff --git a/sd_scripts/train_network.py b/sd_scripts/train_network.py index 8da3ab4d9..143bc6596 100644 --- a/sd_scripts/train_network.py +++ b/sd_scripts/train_network.py @@ -25,6 +25,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler from .library import model_util +from .library import autostats import numpy as np from .library import train_util from .library.train_util import ( @@ -39,6 +40,7 @@ from .library import custom_train_functions from .library.custom_train_functions import ( apply_snr_weight, + get_mask, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, @@ -48,6 +50,7 @@ apply_multichannel_masked_loss, ) from .library.utils import setup_logging, add_logging_arguments +from safetensors import safe_open setup_logging() import logging @@ -143,19 +146,8 @@ def all_reduce_network(self, accelerator, network): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - def process_batch(self, batch, is_train, train_unet, alphas_cumprod, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=True): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor + def process_batch(self, batch, is_train, train_unet, alphas_cumprod, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, latents_from_batch, text_conds_for_batch, std_target_by_ts, mean_target_by_ts, scaled_std_target_by_ts, scaled_mean_target_by_ts, timestep_probs, global_step, train_text_encoder=True): + latents = latents_from_batch(batch) # get multiplier for each sample if network_has_multiplier: @@ -169,25 +161,12 @@ def process_batch(self, batch, is_train, train_unet, alphas_cumprod, network, ne accelerator.unwrap_model(network).set_multiplier(multipliers) with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizers[0], - text_encoders[0], - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) + text_encoder_conds = text_conds_for_batch(batch) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents, scaled_std_target_by_ts, scaled_mean_target_by_ts, timestep_probs ) # ensure the hidden state will require grad @@ -216,7 +195,64 @@ def process_batch(self, batch, is_train, train_unet, alphas_cumprod, network, ne else: target = noise - if args.use_sig_loss: + if args.use_sts_loss: + base_loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + + if args.masked_loss and random.random() < args.masked_loss_prob: + mask = get_mask(batch, noisy_latents).to(dtype=weight_dtype) + noise_pred.register_hook(lambda grad: grad * mask) + + loss = base_loss + autostats_weight = (1-args.autostats_effect_min) * math.exp(-args.autostats_decay_rate * global_step) + args.autostats_effect_min + step_logs["losses/std_loss_prob"] = autostats_weight + if std_target_by_ts is not None: + loss = base_loss.mean(dim=(2, 3), keepdims=True) + + pred_std_loss = F.mse_loss( noise_pred.std(dim=(2,3), keepdims=True).float(), std_target_by_ts[timesteps].float(), reduction="none" ) + loss = loss + pred_std_loss * autostats_weight * args.autostats_loss_weights[0] # * std_weighting[timesteps] + + pred_mean_loss = F.mse_loss( noise_pred.mean(dim=(2,3), keepdims=True).float(), mean_target_by_ts[timesteps].float(), reduction="none" ) + loss = loss + pred_mean_loss * autostats_weight * args.autostats_loss_weights[1] # mean_weighting[timesteps] * + + step_logs["losses/base_loss"] = base_loss.mean().item() + step_logs["losses/pred_std_loss"] = pred_std_loss.mean().item() + step_logs["losses/pred_mean_loss"] = pred_mean_loss.mean().item() + + ch_std = noise_pred.std(dim=(0,2,3)) + step_logs["metrics/std_ch_0"] = ch_std[0].item() + step_logs["metrics/std_ch_1"] = ch_std[1].item() + step_logs["metrics/std_ch_2"] = ch_std[2].item() + step_logs["metrics/std_ch_3"] = ch_std[3].item() + + ch_mean = noise_pred.mean(dim=(0,2,3)) + step_logs["metrics/mean_ch_0"] = ch_mean[0].item() + step_logs["metrics/mean_ch_1"] = ch_mean[1].item() + step_logs["metrics/mean_ch_2"] = ch_mean[2].item() + step_logs["metrics/mean_ch_3"] = ch_mean[3].item() + + loss = loss.mean(dim=(1,2,3)) + + if args.autostats_dynamic_timestep_weighting: + for i in range(loss.shape[0]): + timestep_probs[timesteps[i]] /= 1+loss[i].item() + timestep_probs = autostats.smooth(timestep_probs) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() + elif args.use_sig_loss: mae_loss = F.l1_loss(noise_pred, target, reduction="none") mse_loss = F.mse_loss(noise_pred, target, reduction="none") base_loss = 1/-mse_loss.exp() + 1 @@ -920,6 +956,80 @@ def remove_model(old_ckpt_name): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) + def latents_from_batch(batch): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + return latents * self.vae_scale_factor + + def text_conds_for_batch(batch): + if args.weighted_captions: + return get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + return self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) + + def collect_model_stats(): + samples_collected = 0 + steps = autostats.kerras_timesteps(64) + steps = ((steps / steps.max()).flip(0) * 999).int() + observed_stats = [None] * len(steps) + max_observations = 5 + with torch.no_grad(), accelerator.autocast(), tqdm(total=20, desc="Collecting model stats") as pbar: + for batch in train_dataloader: + if samples_collected == 0: + pbar.reset(max_observations * len(steps)) + samples_collected += batch["input_ids"].shape[0] + latents = latents_from_batch(batch) + text_conds = text_conds_for_batch(batch) + noise = torch.randn_like(latents) + for n in noise: + for channel in range(0, 4): + n[channel].normal_(0, 1) # Ensure that each channel has normal noise + + for index, timestep in enumerate(steps): + timesteps = torch.tensor([timestep] * len(latents)).to(accelerator.device) + noised = noise_scheduler.add_noise(latents, noise, timesteps) + pred = self.call_unet( + args, + accelerator, + unet, + noised, + timesteps, + text_conds, + batch, + weight_dtype, + ) + pbar.update(batch["input_ids"].shape[0]) + if observed_stats[index] is None: + observed_stats[index] = pred + else: + observed_stats[index] = torch.cat([observed_stats[index], pred], dim=0) + if samples_collected >= max_observations: + break + observations = torch.stack(observed_stats) + from safetensors.torch import save_file + save_file({ "observations": observations.contiguous(), "timesteps": torch.tensor(steps) }, args.autostats) + + std_target_by_ts, mean_target_by_ts, scaled_std_target_by_ts, scaled_mean_target_by_ts, timestep_probs = autostats.autostats(args, collect_model_stats) + + std_target_by_ts = std_target_by_ts.to(dtype=weight_dtype, device=accelerator.device) + mean_target_by_ts = mean_target_by_ts.to(dtype=weight_dtype, device=accelerator.device) + scaled_std_target_by_ts = scaled_std_target_by_ts.to(dtype=weight_dtype, device=accelerator.device) + scaled_mean_target_by_ts = scaled_mean_target_by_ts.to(dtype=weight_dtype, device=accelerator.device) + timestep_probs = timestep_probs.to(dtype=weight_dtype, device=accelerator.device) + # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) @@ -941,7 +1051,7 @@ def remove_model(old_ckpt_name): with accelerator.accumulate(network): on_step_start(text_encoder, unet) is_train = True - loss = self.process_batch(batch, is_train, train_unet, alphas_cumprod, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, train_text_encoder=train_text_encoder) + loss = self.process_batch(batch, is_train, train_unet, alphas_cumprod, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, latents_from_batch, text_conds_for_batch, std_target_by_ts, mean_target_by_ts, scaled_std_target_by_ts, scaled_mean_target_by_ts, timestep_probs, global_step, train_text_encoder=train_text_encoder) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -1013,7 +1123,7 @@ def remove_model(old_ckpt_name): for val_step, batch in enumerate(val_dataloader): step_logs = {} is_train = False - loss = self.process_batch(batch, is_train, train_unet, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs) + loss = self.process_batch(batch, is_train, train_unet, alphas_cumprod, network, network_has_multiplier, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, step_logs, latents_from_batch, text_conds_for_batch, std_target_by_ts, mean_target_by_ts, scaled_std_target_by_ts, scaled_mean_target_by_ts, timestep_probs, global_step, train_text_encoder=train_text_encoder) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)