Commit
WIP stage b latents for loss calc
Jeff Ding committed May 21, 2024
1 parent a22c9e1 commit 3ad290b
Showing 2 changed files with 36 additions and 10 deletions.
1 change: 1 addition & 0 deletions sd_scripts/library/stable_cascade.py
@@ -797,6 +797,7 @@ def gen_r_embedding(self, r, max_positions=10000):
     def gen_c_embeddings(self, clip):
         if len(clip.shape) == 2:
             clip = clip.unsqueeze(1)
+        clip = clip.to(device=self.clip_mapper.weight.device, dtype=self.clip_mapper.weight.dtype)
         clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
         clip = self.clip_norm(clip)
         return clip
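Note: the single added line is a defensive device/dtype cast. Under mixed precision the pooled CLIP embeddings can still be fp32 on the CPU while clip_mapper's weights are fp16 on the GPU, and a linear layer raises a type/device mismatch in that case. A minimal sketch of the same pattern, with hypothetical shapes and assuming clip_mapper is a plain nn.Linear on a CUDA device:

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for Stage C's clip_mapper: fp16 weights on GPU.
    clip_mapper = nn.Linear(1280, 2048).to(device="cuda", dtype=torch.float16)

    clip = torch.randn(4, 1280)  # pooled CLIP embeddings, still fp32 on CPU
    if len(clip.shape) == 2:
        clip = clip.unsqueeze(1)

    # Mirrors the added line: align the input with the mapper's parameters
    # before the matmul instead of assuming the caller already did.
    clip = clip.to(device=clip_mapper.weight.device, dtype=clip_mapper.weight.dtype)
    out = clip_mapper(clip)  # shape (4, 1, 2048)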
45 changes: 35 additions & 10 deletions sd_scripts/stable_cascade_train_c_network.py
@@ -15,6 +15,7 @@
 
 import torch
 from .library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 
 import numpy as np
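Note: init_ipex() runs immediately after its import, before the remaining torch-dependent imports. As an assumption (device_utils is not shown in this diff), it behaves roughly like the guarded setup below and is a no-op when Intel's extension is absent:

    # Assumed shape of device_utils.init_ipex, not part of this commit:
    def init_ipex():
        try:
            import intel_extension_for_pytorch  # noqa: F401  enables torch.xpu
        except ImportError:
            pass  # no Intel GPU stack installed; CUDA/CPU paths are unchanged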
@@ -179,7 +180,7 @@ def sample_images(self, accelerator, args, epoch, global_step, device, vae, toke
         train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
 
     def train(self, args):
-        session_id = random.randint(0, 2**32)
+        session_id = random.randint(0, 2 ** 32)
         training_started_at = time.time()
 
         train_util.verify_training_args(args)
@@ -191,7 +192,7 @@ def train(self, args):
         use_user_config = args.dataset_config is not None
 
         if args.seed is None:
-            args.seed = random.randint(0, 2**32)
+            args.seed = random.randint(0, 2 ** 32)
         set_seed(args.seed)
 
         # tokenizer is a single instance or a list; tokenizers is always a list (for compatibility with existing code)
@@ -284,6 +285,8 @@ def train(self, args):
         loading_device = accelerator.device if args.lowram else "cpu"
         effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
         stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=weight_dtype, device=loading_device)
+        if args.use_stage_b:
+            stage_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=weight_dtype, device=loading_device)
         text_encoder = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)
         model_version = sc.MODEL_VERSION_STABLE_CASCADE
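Note: stage_b exists only when --use_stage_b is passed; every later reference in this commit sits behind the same flag, otherwise the name would raise NameError. A slightly more defensive variant of the pattern (a sketch, not the commit's code) binds the name unconditionally:

    # Sketch: bind the name up front so any unguarded reference fails as an
    # explicit None rather than a NameError; the flag still gates the load.
    stage_b = None
    if args.use_stage_b:
        stage_b = sc_utils.load_stage_b_model(
            args.stage_b_checkpoint_path, dtype=weight_dtype, device=loading_device
        )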

@@ -600,7 +603,7 @@ def train(self, args):
             "ss_network_dropout": args.network_dropout, # some networks may not have dropout
             "ss_mixed_precision": args.mixed_precision,
             "ss_full_fp16": bool(args.full_fp16),
-            "ss_v2": False, # bool(args.v2),
+            "ss_v2": False,  # bool(args.v2),
             "ss_base_model_version": model_version,
             "ss_clip_skip": args.clip_skip,
             "ss_max_token_length": args.max_token_length,
@@ -792,7 +795,7 @@ def train(self, args):
             target=sc.EpsilonTarget(),
             noise_cond=sc.CosineTNoiseCond(),
             loss_weight=sc.AdaptiveLossWeight() if args.adaptive_loss_weight else sc.P2LossWeight(),
-            num_timesteps= 1000 if args.max_train_steps is None else args.max_train_steps,
+            num_timesteps=1000 if args.max_train_steps is None else args.max_train_steps,
         )
 
         if accelerator.is_main_process:
@@ -845,7 +848,7 @@ def remove_model(old_ckpt_name):
         epoch_loss_map = {}
         # training loop
         for epoch in range(num_train_epochs):
-            accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
+            accelerator.print(f"\nepoch {epoch + 1}/{num_train_epochs}")
             current_epoch.value = epoch + 1
 
             metadata["ss_epoch"] = str(epoch + 1)
@@ -907,20 +910,35 @@ def remove_model(old_ckpt_name):
                         noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
                     )
 
+                    with torch.no_grad():
+                        imgs = batch["images"]
+                        if imgs is not None:
+                            effnet_embeddings = effnet(imgs.to(accelerator.device, dtype=effnet_dtype))
+                        else:
+                            effnet_embeddings = None
+
+                        if args.use_stage_b:
+                            pred_b = stage_b(
+                                noised, noise_cond, effnet=effnet_embeddings, clip=pool.to(accelerator.device)
+                            )
+
                    if args.use_sig_loss:
                         timesteps = gdf.sample_timesteps(latents.size(0))
                         alphas_cumprod = gdf.alphas_cumprod.to(latents.device)
                         ac = alphas_cumprod[timesteps]
                         mae_loss = torch.nn.functional.l1_loss(pred, target, reduction="none")
-                        base_loss = 1/-mae_loss.exp() + 1
+                        base_loss = 1 / -mae_loss.exp() + 1
                         loss = base_loss.mean(dim=(2, 3), keepdims=True) * ac
-                        loss = loss + base_loss.std(dim=(2,3), keepdims=True) * (1-ac)
+                        loss = loss + base_loss.std(dim=(2, 3), keepdims=True) * (1 - ac)
 
-                        loss = loss.mean(dim=(1,2,3))
+                        loss = loss.mean(dim=(1, 2, 3))
                         loss_adjusted = (loss * loss_weight)
                         loss_adjusted = loss_adjusted.mean()
                     else:
                         loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
+                        if args.use_stage_b:
+                            loss_b = torch.nn.functional.mse_loss(pred_b.float(), target.float(), reduction="none")
+                            loss = (loss + loss_b) / 2
 
                         if args.masked_loss and args.weighted_loss:
                             loss, noise_mask = apply_multichannel_masked_loss(loss, batch, args.background_weight, args.character_weight, args.detail_weight)
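Note: two things are worth spelling out in this hunk. First, the sig-loss transform 1 / -mae_loss.exp() + 1 is algebraically 1 - exp(-mae_loss), a saturating remap of the per-element L1 error into [0, 1). Second, when --use_stage_b is set, the batch images are re-encoded with EfficientNet, Stage B predicts from the same noised latents, and in the plain-MSE branch its per-element loss is averaged with Stage C's. As reconstructed here the Stage B forward sits under torch.no_grad() (the indentation is ambiguous in this view), in which case its term adds no gradient of its own and effectively halves the scale of the Stage C gradient, which may be intentional for a WIP loss calculation. A self-contained sketch of both pieces, with function names that are mine rather than the script's:

    import torch
    import torch.nn.functional as F

    def sig_base_loss(pred, target):
        # 1 / -exp(L1) + 1  ==  1 - exp(-L1): near-linear for small errors,
        # saturating toward 1 for large ones instead of growing unbounded.
        mae = F.l1_loss(pred, target, reduction="none")
        return 1 - torch.exp(-mae)

    def combined_mse_loss(pred_c, target, pred_b=None):
        # Per-element MSE on the Stage C prediction; when a Stage B
        # prediction is supplied, take the plain mean of the two losses,
        # exactly as the else-branch above does.
        loss = F.mse_loss(pred_c.float(), target.float(), reduction="none")
        if pred_b is not None:
            loss_b = F.mse_loss(pred_b.float(), target.float(), reduction="none")
            loss = (loss + loss_b) / 2
        return loss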
@@ -942,8 +960,8 @@ def remove_model(old_ckpt_name):
 
                     step_logs["metrics/noise_pred_std"] = pred_std.mean().item()
                     step_logs["metrics/noise_pred_mean"] = pred.mean()
-                    step_logs["metrics/std_divergence"] = true_std.mean().item() - pred_std.mean().item()
-                    step_logs["metrics/skew_divergence"] = true_skews.mean().item() - pred_skews.mean().item()
+                    step_logs["metrics/std_divergence"] = true_std.mean().item() - pred_std.mean().item()
+                    step_logs["metrics/skew_divergence"] = true_skews.mean().item() - pred_skews.mean().item()
                     step_logs["metrics/kurtosis_divergence"] = true_kurtoses.mean().item() - pred_kurtoses.mean().item()
 
                     loss_adjusted = loss_adjusted.mean()
@@ -1209,11 +1227,18 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--use_stage_b",
+        action="store_true",
+        help="run Stage B on the noised latents and average its loss with the Stage C loss",
+    )
     return parser
 
+
 def is_decreasing(list, min_delta=0.001):
     return all(list[i] - list[i + 1] >= min_delta for i in range(len(list) - 1))
 
+
 def is_increasing(list, min_delta=0.001):
     return all(list[i + 1] - list[i] >= min_delta for i in range(len(list) - 1))
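Note: the original help string for --use_stage_b was copy-pasted from the VAE flag above it; the description shown here is rewritten from the flag's actual behavior in this diff. Only --use_stage_b is added by this commit; the checkpoint-path flags it pairs with (args.stage_b_checkpoint_path and the other *_checkpoint_path arguments referenced above) are presumably defined elsewhere in the parser. A hypothetical invocation with placeholder paths, module-style since the script uses a relative import:

    python -m sd_scripts.stable_cascade_train_c_network \
        --stage_c_checkpoint_path /path/to/stage_c.safetensors \
        --stage_b_checkpoint_path /path/to/stage_b.safetensors \
        --effnet_checkpoint_path /path/to/effnet_encoder.safetensors \
        --text_model_checkpoint_path /path/to/clip_text.safetensors \
        --use_stage_b

(Incidentally, the two helper functions shadow the built-in name list; renaming the parameter would be cleaner, though they predate this commit.)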

