Added schedule free opt, updated training args
Jeff Ding committed Jul 29, 2024
1 parent 018e339 commit c89b8fc
Showing 3 changed files with 76 additions and 4 deletions.
22 changes: 19 additions & 3 deletions sd_scripts/fine_tune.py
@@ -246,13 +246,26 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
        unet.to(weight_dtype)
        text_encoder.to(weight_dtype)

+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
    # accelerator apparently takes care of things for us
    if args.train_text_encoder:
-        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader
        )
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
    else:
-        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda functions for calling optimizer.train() and optimizer.eval() if a schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None

    # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
@@ -310,6 +323,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
            m.train()

        for step, batch in enumerate(train_dataloader):
+            optimizer_train_if_needed()
            current_step.value = global_step
            with accelerator.accumulate(training_models[0]):  # does not seem to support multiple models, but leave it like this for now
                with torch.no_grad():
@@ -379,6 +393,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

+            optimizer_eval_if_needed()
+
            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
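For readers who have not used schedule-free optimizers, here is a minimal, self-contained sketch of the pattern fine_tune.py now follows: put the optimizer into train mode before stepping, skip the LR scheduler (the optimizer manages its own effective learning rate), and switch to eval mode before any evaluation or sampling. This is illustrative only and not part of the commit; the tiny model and random data are placeholders, and it assumes the schedulefree package is installed.

import torch
import schedulefree

model = torch.nn.Linear(10, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

x, y = torch.randn(64, 10), torch.randn(64, 1)
for step in range(100):
    optimizer.train()                       # required before optimizer.step()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()                        # note: no lr_scheduler.step()
    optimizer.zero_grad(set_to_none=True)
optimizer.eval()                            # required before evaluation, sampling, or saving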
37 changes: 37 additions & 0 deletions sd_scripts/library/train_util.py
@@ -3226,6 +3226,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
        choices=["l2", "huber", "smooth_l1"],
        help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2",
    )
+    parser.add_argument(
+        "--huber_schedule",
+        type=str,
+        default="snr",
+        choices=["constant", "exponential", "snr"],
+        help="The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is snr"
+        + " / Huber損失のスケジューリング方法(constant、exponential、またはSNRベース)。loss_typeが'huber'または'smooth_l1'の場合に有効、デフォルトは snr",
+    )
+    parser.add_argument(
+        "--huber_c",
+        type=float,
+        default=0.1,
+        help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
+    )
    parser.add_argument(
        "--lowram",
        action="store_true",
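These two flags parametrize a Huber-style loss. As a point of reference only (the exact formulation used by the scripts, and how the constant/exponential/snr schedules vary huber_c over timesteps, is implemented elsewhere in train_util.py and is not shown in this diff), the textbook Huber loss is quadratic for small residuals and linear for large ones, so outliers pull on training less than with plain L2:

import torch

def huber_loss(residual: torch.Tensor, c: float = 0.1) -> torch.Tensor:
    # quadratic for |r| <= c, linear beyond c (textbook definition, for intuition only)
    abs_r = residual.abs()
    return torch.where(abs_r <= c, 0.5 * residual**2, c * (abs_r - 0.5 * c))

r = torch.randn(4) * 5
print(huber_loss(r, c=0.1))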
@@ -4090,6 +4104,21 @@ def get_optimizer(args, trainable_params):
        optimizer_class = torch.optim.AdamW
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

+    elif optimizer_type.endswith("schedulefree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+        if optimizer_type == "AdamWScheduleFree".lower():
+            optimizer_class = sf.AdamWScheduleFree
+            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "SGDScheduleFree".lower():
+            optimizer_class = sf.SGDScheduleFree
+            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
+        else:
+            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
    if optimizer is None:
        # use an arbitrary user-specified optimizer
        optimizer_type = args.optimizer_type  # the non-lowercased value (a bit awkward)
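To see how this branch is reached: optimizer_type is the lower-cased value of --optimizer_type, so the flag is effectively case-insensitive, and optimizer_kwargs is built from any --optimizer_args key=value pairs and forwarded to the constructor. A standalone sketch of that routing (the kwargs values below are assumed examples, not repository defaults):

import torch
import schedulefree as sf

optimizer_type = "AdamWScheduleFree".lower()                      # what the user passed, lower-cased
optimizer_kwargs = {"weight_decay": 0.01, "warmup_steps": 100}    # e.g. parsed from --optimizer_args

assert optimizer_type.endswith("schedulefree")                    # selects the new elif branch
params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = sf.AdamWScheduleFree(params, lr=1e-4, **optimizer_kwargs)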
@@ -4118,6 +4147,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
"""
Unified API to get any scheduler from its name.
"""

if args.optimizer_type.lower().endswith("schedulefree"):
# return dummy scheduler: it has 'step' method but does nothing
logger.info("use dummy scheduler for schedule free optimizer / schedule free optimizer用のダミースケジューラを使用します")
lr_scheduler = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.CONSTANT](optimizer)
lr_scheduler.step = lambda: None
return lr_scheduler

name = args.lr_scheduler
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
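The point of the no-op step is that the training loops keep calling lr_scheduler.step() unconditionally; with a schedule-free optimizer that call must do nothing, since the optimizer handles its own effective learning rate. A minimal sketch of the same trick, using a plain constant PyTorch scheduler as a stand-in for TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.CONSTANT]:

import torch
from torch.optim.lr_scheduler import LambdaLR

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
lr_scheduler = LambdaLR(opt, lambda _: 1.0)   # constant schedule
lr_scheduler.step = lambda: None              # existing lr_scheduler.step() call sites become no-ops
lr_scheduler.step()                           # safe to call; the learning rate never changes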
21 changes: 20 additions & 1 deletion sd_scripts/train_network.py
@@ -627,7 +627,19 @@ def train(self, args):
        else:
            pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set

-        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
+        use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
+
+        network, optimizer, train_dataloader = accelerator.prepare(network, optimizer, train_dataloader)
+
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
+
+        if use_schedule_free_optimizer:
+            optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+            optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+        else:
+            optimizer_train_if_needed = lambda: None
+            optimizer_eval_if_needed = lambda: None

        if args.gradient_checkpointing:
            # according to TI example in Diffusers, train is required
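Unlike in fine_tune.py, the lambdas here reach through a possible wrapper: accelerator.prepare() may hand back the optimizer wrapped in an object that keeps the real optimizer in an .optimizer attribute (accelerate's AcceleratedOptimizer does this), and train()/eval() exist only on the inner schedule-free optimizer. A standalone sketch with a stand-in wrapper class (illustrative, not the accelerate implementation):

import torch
import schedulefree

class WrappedOptimizer:                              # stand-in for an accelerate-style wrapper
    def __init__(self, optimizer):
        self.optimizer = optimizer
    def step(self):
        self.optimizer.step()

inner = schedulefree.AdamWScheduleFree([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
optimizer = WrappedOptimizer(inner)

# the same unwrapping the lambdas perform:
target = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
target.train()
target.eval()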
@@ -1030,6 +1042,10 @@ def collect_model_stats():
            scaled_mean_target_by_ts = scaled_mean_target_by_ts.to(dtype=weight_dtype, device=accelerator.device)
            timestep_probs = timestep_probs.to(dtype=weight_dtype, device=accelerator.device)

+        if self.is_sdxl:
+            ts = 500
+            mean_target_by_ts[:ts, 3] = mean_target_by_ts[:ts, 3] * torch.arange(0, 1.0, 1 / ts, device=mean_target_by_ts.device).view(-1, 1, 1)
+
        # For --sample_at_first
        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
@@ -1047,6 +1063,7 @@ def collect_model_stats():
            alphas_cumprod = noise_scheduler.alphas_cumprod.to(accelerator.device)
            for step, batch in enumerate(train_dataloader):
                step_logs = {}
+                optimizer_train_if_needed()
                current_step.value = global_step
                with accelerator.accumulate(network):
                    on_step_start(text_encoder, unet)
@@ -1070,6 +1087,8 @@ def collect_model_stats():
                    else:
                        keys_scaled, mean_norm, maximum_norm = None, None, None

+                optimizer_eval_if_needed()
+
                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
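In schedule-free optimizers, train() and eval() switch the parameters between the form used for gradient steps and the averaged form intended for inference, so calling optimizer_eval_if_needed() before the sampling/saving section means any images sampled or checkpoints written see the evaluation weights; the next iteration switches back via optimizer_train_if_needed(). A standalone sketch of that round trip (illustrative values only):

import torch
import schedulefree

p = torch.nn.Parameter(torch.ones(2))
opt = schedulefree.AdamWScheduleFree([p], lr=0.1)

opt.train()                       # parameters take the form used for gradient steps
loss = (p ** 2).sum()
loss.backward()
opt.step()

opt.eval()                        # parameters switch to the averaged weights used for inference
weights_for_sampling = p.detach().clone()
opt.train()                       # switch back before the next training step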