Skip to content

Commit

Permalink
Added param args to control weighted loss and weights for multi-mask …
Browse files Browse the repository at this point in the history
…loss
  • Loading branch information
Jeff Ding committed May 3, 2024
1 parent 95d94b0 commit 322e640
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions sd_scripts/stable_cascade_train_c_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,22 +915,18 @@ def remove_model(old_ckpt_name):
base_loss = 1/-mae_loss.exp() + 1
loss = base_loss.mean(dim=(2, 3), keepdims=True) * ac
loss = loss + base_loss.std(dim=(2,3), keepdims=True) * (1-ac)
if args.masked_loss and np.random.rand() < args.masked_loss_prob:
# loss, noise_mask = apply_masked_loss(loss, batch)
loss, noise_mask = apply_multichannel_masked_loss(loss, batch, 1.0, 1.0, 2.0)
else:
noise_mask = torch.ones_like(noise, device=noise.device)

loss = loss.mean(dim=(1,2,3))
loss_adjusted = (loss * loss_weight)
loss_adjusted = loss_adjusted.mean()
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

if args.masked_loss and np.random.rand() < args.masked_loss_prob:
# loss, noise_mask = apply_masked_loss(loss, batch)
loss, noise_mask = apply_multichannel_masked_loss(loss, batch, 1.0, 1.5, 2.0)
if args.masked_loss and args.weighted_loss:
loss, noise_mask = apply_multichannel_masked_loss(loss, batch, args.background_weight, args.character_weight, args.detail_weight)
noise_mask = torch.ones_like(noise, device=noise.device)
elif args.masked_loss and np.random.rand() < args.masked_loss_prob:
loss, noise_mask = apply_masked_loss(loss, batch)
else:
noise_mask = torch.ones_like(noise, device=noise.device)

Expand Down Expand Up @@ -1213,6 +1209,29 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--background_weight",
type=float,
default=1.0,
help="Background weight for multi-mask training",
)
parser.add_argument(
"--character_weight",
type=float,
default=1.0,
help="Character weight for multi-mask training",
)
parser.add_argument(
"--detail_weight",
type=float,
default=1.0,
help="Detail weight for multi-mask training",
)
parser.add_argument(
"--weighted_loss",
action="store_true",
help="Use multi-mask weighted loss",
)
return parser

def is_decreasing(list, min_delta=0.001):
Expand Down

0 comments on commit 322e640

Please sign in to comment.