forked from jeffreyding18/sd-scripts
Initial commit for adapting slider lora training, specifically for stable cascade
Jeff Ding committed Apr 9, 2024
1 parent c27c282 · commit 38769fb
Showing 8 changed files with 1,837 additions and 0 deletions.
Empty file.
@@ -0,0 +1,104 @@
from typing import Literal, Optional

import yaml

from pydantic import BaseModel
import torch

from lora import TRAINING_METHODS

PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
NETWORK_TYPES = Literal["lierla", "c3lier"]


class PretrainedModelConfig(BaseModel):
    name_or_path: str
    v2: bool = False
    v_pred: bool = False

    clip_skip: Optional[int] = None


class NetworkConfig(BaseModel):
    type: NETWORK_TYPES = "lierla"
    rank: int = 4
    alpha: float = 1.0

    training_method: TRAINING_METHODS = "full"


class TrainConfig(BaseModel):
    precision: PRECISION_TYPES = "bfloat16"
    noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"

    iterations: int = 500
    lr: float = 1e-4
    optimizer: str = "adamw"
    optimizer_args: str = ""
    lr_scheduler: str = "constant"

    max_denoising_steps: int = 50


class SaveConfig(BaseModel):
    name: str = "untitled"
    path: str = "./output"
    per_steps: int = 200
    precision: PRECISION_TYPES = "float32"


class LoggingConfig(BaseModel):
    use_wandb: bool = False

    verbose: bool = False


class OtherConfig(BaseModel):
    use_xformers: bool = False


class RootConfig(BaseModel):
    prompts_file: str
    pretrained_model: PretrainedModelConfig

    network: NetworkConfig

    train: Optional[TrainConfig]

    save: Optional[SaveConfig]

    logging: Optional[LoggingConfig]

    other: Optional[OtherConfig]


def parse_precision(precision: str) -> torch.dtype:
    if precision == "fp32" or precision == "float32":
        return torch.float32
    elif precision == "fp16" or precision == "float16":
        return torch.float16
    elif precision == "bf16" or precision == "bfloat16":
        return torch.bfloat16

    raise ValueError(f"Invalid precision type: {precision}")


def load_config_from_yaml(config_path: str) -> RootConfig:
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    root = RootConfig(**config)

    if root.train is None:
        root.train = TrainConfig()

    if root.save is None:
        root.save = SaveConfig()

    if root.logging is None:
        root.logging = LoggingConfig()

    if root.other is None:
        root.other = OtherConfig()

    return root
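For reference, a minimal usage sketch of the config loader above (not part of the commit). The diff does not show file names, so the module name config_util, the YAML file name, and all values below are assumptions for illustration; omitted optional sections are filled with their defaults by load_config_from_yaml.

# example_config.yaml (hypothetical) might look like:
#
#   prompts_file: "prompts.yaml"
#   pretrained_model:
#     name_or_path: "stabilityai/stable-cascade"
#   network:
#     rank: 4
#     alpha: 1.0
#     training_method: "full"
#
# train/save/logging/other are optional and default to TrainConfig(),
# SaveConfig(), LoggingConfig(), and OtherConfig() respectively.

from config_util import load_config_from_yaml, parse_precision  # assumed module name

config = load_config_from_yaml("example_config.yaml")
dtype = parse_precision(config.train.precision)  # e.g. torch.bfloat16
print(config.network.rank, config.save.path, dtype)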
@@ -0,0 +1,14 @@
import torch


def check_requires_grad(model: torch.nn.Module):
    for name, module in list(model.named_modules())[:5]:
        if len(list(module.parameters())) > 0:
            print(f"Module: {name}")
            for name, param in list(module.named_parameters())[:2]:
                print(f"    Parameter: {name}, Requires Grad: {param.requires_grad}")


def check_training_mode(model: torch.nn.Module):
    for name, module in list(model.named_modules())[:5]:
        print(f"Module: {name}, Training Mode: {module.training}")
@@ -0,0 +1,256 @@
# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py

import os
import math
from typing import Optional, List, Type, Set, Literal

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file


UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
    # "Transformer2DModel",  # apparently this is the right one?  # attn1, 2
    "Attention"
]
UNET_TARGET_REPLACE_MODULE_CONV = [
    "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
    # "DownBlock2D",
    # "UpBlock2D"
]  # locon, c3lier

LORA_PREFIX_UNET = "lora_unet"

DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER

TRAINING_METHODS = Literal[
    "noxattn",  # train all layers except x-attns and time_embed layers
    "innoxattn",  # train all layers except self attention layers
    "selfattn",  # ESD-u, train only self attention layers
    "xattn",  # ESD-x, train only x attention layers
    "full",  # train all layers
    "xattn-strict",  # q and k values
    "noxattn-hspace",
    "noxattn-hspace-last",
    # "xlayer",
    # "outxattn",
    # "outsattn",
    # "inxattn",
    # "inmidsattn",
    # "selflayer",
]


class LoRAModule(nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        if "Linear" in org_module.__class__.__name__:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)

        elif "Conv" in org_module.__class__.__name__:  # just in case
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels

            self.lora_dim = min(self.lora_dim, in_dim, out_dim)
            if self.lora_dim != lora_dim:
                print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")

            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        return (
            self.org_forward(x)
            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        )


class LoRANetwork(nn.Module):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        rank: int = 4,
        multiplier: float = 1.0,
        alpha: float = 1.0,
        train_method: TRAINING_METHODS = "full",
    ) -> None:
        super().__init__()
        self.lora_scale = 1
        self.multiplier = multiplier
        self.lora_dim = rank
        self.alpha = alpha

        # LoRA only
        self.module = LoRAModule

        # create LoRA modules for the U-Net
        self.unet_loras = self.create_modules(
            LORA_PREFIX_UNET,
            unet,
            DEFAULT_TARGET_REPLACE,
            self.lora_dim,
            self.multiplier,
            train_method=train_method,
        )
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # assertion: make sure there are no duplicated names
        lora_names = set()
        for lora in self.unet_loras:
            assert (
                lora.lora_name not in lora_names
            ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
            lora_names.add(lora.lora_name)

        # apply the LoRA modules
        for lora in self.unet_loras:
            lora.apply_to()
            self.add_module(
                lora.lora_name,
                lora,
            )

        del unet

        torch.cuda.empty_cache()

    def create_modules(
        self,
        prefix: str,
        root_module: nn.Module,
        target_replace_modules: List[str],
        rank: int,
        multiplier: float,
        train_method: TRAINING_METHODS,
    ) -> list:
        loras = []
        names = []
        for name, module in root_module.named_modules():
            if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last":  # train everything except cross attention and time embed
                if "attn2" in name or "time_embed" in name:
                    continue
            elif train_method == "innoxattn":  # train everything except cross attention
                if "attn2" in name:
                    continue
            elif train_method == "selfattn":  # train only self attention
                if "attn1" not in name:
                    continue
            elif train_method == "xattn" or train_method == "xattn-strict":  # train only cross attention
                if "attn2" not in name:
                    continue
            elif train_method == "full":  # train everything
                pass
            else:
                raise NotImplementedError(
                    f"train_method: {train_method} is not implemented."
                )
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
                        if train_method == 'xattn-strict':
                            if 'out' in child_name:
                                continue
                        if train_method == 'noxattn-hspace':
                            if 'mid_block' not in name:
                                continue
                        if train_method == 'noxattn-hspace-last':
                            if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
                                continue
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        # print(f"{lora_name}")
                        lora = self.module(
                            lora_name, child_module, multiplier, rank, self.alpha
                        )
                        # print(name, child_name)
                        # print(child_module.weight.shape)
                        loras.append(lora)
                        names.append(lora_name)
        # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
        return loras

    def prepare_optimizer_params(self):
        all_params = []

        if self.unet_loras:  # effectively the only case
            params = []
            [params.extend(lora.parameters()) for lora in self.unet_loras]
            param_data = {"params": params}
            all_params.append(param_data)

        return all_params

    def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        # for key in list(state_dict.keys()):
        #     if not key.startswith("lora"):
        #         # drop everything except lora keys
        #         del state_dict[key]

        if os.path.splitext(file)[1] == ".safetensors":
            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def set_lora_slider(self, scale):
        self.lora_scale = scale

    def __enter__(self):
        for lora in self.unet_loras:
            lora.multiplier = 1.0 * self.lora_scale

    def __exit__(self, exc_type, exc_value, tb):
        for lora in self.unet_loras:
            lora.multiplier = 0
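For orientation, a hedged usage sketch (not part of the commit): LoRANetwork as written wraps a diffusers UNet2DConditionModel, so the checkpoint name, hyperparameters, output file name, and the training-loop placeholder below are illustrative assumptions, not the commit's actual Stable Cascade training script.

import torch
from diffusers import UNet2DConditionModel

from lora import LoRANetwork  # module name as imported by the config module above

# Illustrative checkpoint; any diffusers UNet2DConditionModel works here.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.requires_grad_(False)

# Wrap the UNet; cross-attention-only training ("xattn") chosen for illustration.
network = LoRANetwork(unet, rank=4, multiplier=1.0, alpha=1.0, train_method="xattn")
optimizer = torch.optim.AdamW(network.prepare_optimizer_params(), lr=1e-4)

# The slider scale sets the LoRA strength used inside the context manager:
# __enter__ applies multiplier * lora_scale, __exit__ zeroes the multiplier.
network.set_lora_slider(scale=2.0)
with network:
    ...  # run the UNet forward pass and compute the loss here

network.save_weights("slider_lora.safetensors", dtype=torch.float16)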