Commit

Initial commit for adapting slider lora training, specifically for stable cascade

Jeff Ding committed Apr 9, 2024
1 parent c27c282 commit 38769fb
Showing 8 changed files with 1,837 additions and 0 deletions.
Empty file.
104 changes: 104 additions & 0 deletions sd_scripts/imagesliders/config_util.py
@@ -0,0 +1,104 @@
from typing import Literal, Optional

import yaml

from pydantic import BaseModel
import torch

from lora import TRAINING_METHODS

PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
NETWORK_TYPES = Literal["lierla", "c3lier"]


class PretrainedModelConfig(BaseModel):
name_or_path: str
v2: bool = False
v_pred: bool = False

clip_skip: Optional[int] = None


class NetworkConfig(BaseModel):
type: NETWORK_TYPES = "lierla"
rank: int = 4
alpha: float = 1.0

training_method: TRAINING_METHODS = "full"


class TrainConfig(BaseModel):
precision: PRECISION_TYPES = "bfloat16"
noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"

iterations: int = 500
lr: float = 1e-4
optimizer: str = "adamw"
optimizer_args: str = ""
lr_scheduler: str = "constant"

max_denoising_steps: int = 50


class SaveConfig(BaseModel):
name: str = "untitled"
path: str = "./output"
per_steps: int = 200
precision: PRECISION_TYPES = "float32"


class LoggingConfig(BaseModel):
use_wandb: bool = False

verbose: bool = False


class OtherConfig(BaseModel):
use_xformers: bool = False


class RootConfig(BaseModel):
prompts_file: str
pretrained_model: PretrainedModelConfig

network: NetworkConfig

train: Optional[TrainConfig]

save: Optional[SaveConfig]

logging: Optional[LoggingConfig]

other: Optional[OtherConfig]


def parse_precision(precision: str) -> torch.dtype:
if precision == "fp32" or precision == "float32":
return torch.float32
elif precision == "fp16" or precision == "float16":
return torch.float16
elif precision == "bf16" or precision == "bfloat16":
return torch.bfloat16

raise ValueError(f"Invalid precision type: {precision}")


def load_config_from_yaml(config_path: str) -> RootConfig:
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)

root = RootConfig(**config)

if root.train is None:
root.train = TrainConfig()

if root.save is None:
root.save = SaveConfig()

if root.logging is None:
root.logging = LoggingConfig()

if root.other is None:
root.other = OtherConfig()

return root
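
For illustration only (not part of the commit), a minimal sketch of how this config module might be driven; the YAML contents, file names, and model path below are assumptions, and the example relies on the same pydantic-v1 Optional semantics the RootConfig class does:

# Hypothetical usage sketch for config_util.py; paths and values are illustrative.
from config_util import load_config_from_yaml, parse_precision

example_yaml = """
prompts_file: prompts.yaml
pretrained_model:
  name_or_path: stabilityai/stable-cascade
network:
  rank: 4
  alpha: 1.0
train:
  precision: bfloat16
  iterations: 500
"""

with open("config.yaml", "w") as f:  # hypothetical path
    f.write(example_yaml)

config = load_config_from_yaml("config.yaml")
dtype = parse_precision(config.train.precision)  # torch.bfloat16
print(config.network.rank, config.save.path)     # 4 ./output (SaveConfig defaults filled in)
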
14 changes: 14 additions & 0 deletions sd_scripts/imagesliders/debug_util.py
@@ -0,0 +1,14 @@
import torch


def check_requires_grad(model: torch.nn.Module):
for name, module in list(model.named_modules())[:5]:
if len(list(module.parameters())) > 0:
print(f"Module: {name}")
for name, param in list(module.named_parameters())[:2]:
print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")


def check_training_mode(model: torch.nn.Module):
for name, module in list(model.named_modules())[:5]:
print(f"Module: {name}, Training Mode: {module.training}")
256 changes: 256 additions & 0 deletions sd_scripts/imagesliders/lora.py
@@ -0,0 +1,256 @@
# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py

import os
import math
from typing import Optional, List, Type, Set, Literal

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file


UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
# "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
"Attention"
]
UNET_TARGET_REPLACE_MODULE_CONV = [
"ResnetBlock2D",
"Downsample2D",
"Upsample2D",
# "DownBlock2D",
# "UpBlock2D"
]  # locon, c3lier

LORA_PREFIX_UNET = "lora_unet"

DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER

TRAINING_METHODS = Literal[
"noxattn", # train all layers except x-attns and time_embed layers
"innoxattn", # train all layers except self attention layers
"selfattn", # ESD-u, train only self attention layers
"xattn", # ESD-x, train only x attention layers
"full", # train all layers
"xattn-strict", # q and k values
"noxattn-hspace",
"noxattn-hspace-last",
# "xlayer",
# "outxattn",
# "outsattn",
# "inxattn",
# "inmidsattn",
# "selflayer",
]


class LoRAModule(nn.Module):
"""
    Replaces the forward method of the original Linear (or Conv) module instead of replacing the module itself.
"""

def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim

if "Linear" in org_module.__class__.__name__:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)

elif "Conv" in org_module.__class__.__name__: # 一応
in_dim = org_module.in_channels
out_dim = org_module.out_channels

self.lora_dim = min(self.lora_dim, in_dim, out_dim)
if self.lora_dim != lora_dim:
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")

kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = nn.Conv2d(
in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
)
self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)

if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

# same as microsoft's
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_up.weight)

self.multiplier = multiplier
        self.org_module = org_module  # this reference is deleted in apply_to()

def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module

def forward(self, x):
return (
self.org_forward(x)
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
)


class LoRANetwork(nn.Module):
def __init__(
self,
unet: UNet2DConditionModel,
rank: int = 4,
multiplier: float = 1.0,
alpha: float = 1.0,
train_method: TRAINING_METHODS = "full",
) -> None:
super().__init__()
self.lora_scale = 1
self.multiplier = multiplier
self.lora_dim = rank
self.alpha = alpha

        # LoRA only
self.module = LoRAModule

        # create the LoRA modules for the U-Net
self.unet_loras = self.create_modules(
LORA_PREFIX_UNET,
unet,
DEFAULT_TARGET_REPLACE,
self.lora_dim,
self.multiplier,
train_method=train_method,
)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # assertion: check that no LoRA names are duplicated
lora_names = set()
for lora in self.unet_loras:
assert (
lora.lora_name not in lora_names
), f"duplicated lora name: {lora.lora_name}. {lora_names}"
lora_names.add(lora.lora_name)

        # apply the LoRA modules to the original modules
for lora in self.unet_loras:
lora.apply_to()
self.add_module(
lora.lora_name,
lora,
)

del unet

torch.cuda.empty_cache()

def create_modules(
self,
prefix: str,
root_module: nn.Module,
target_replace_modules: List[str],
rank: int,
multiplier: float,
train_method: TRAINING_METHODS,
) -> list:
loras = []
names = []
for name, module in root_module.named_modules():
            if train_method in ("noxattn", "noxattn-hspace", "noxattn-hspace-last"):  # train everything except cross-attention and time_embed
                if "attn2" in name or "time_embed" in name:
                    continue
            elif train_method == "innoxattn":  # train everything except cross-attention
                if "attn2" in name:
                    continue
            elif train_method == "selfattn":  # train self-attention only
                if "attn1" not in name:
                    continue
            elif train_method in ("xattn", "xattn-strict"):  # train cross-attention only
                if "attn2" not in name:
                    continue
            elif train_method == "full":  # train all layers
                pass
else:
raise NotImplementedError(
f"train_method: {train_method} is not implemented."
)
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
if train_method == 'xattn-strict':
if 'out' in child_name:
continue
if train_method == 'noxattn-hspace':
if 'mid_block' not in name:
continue
if train_method == 'noxattn-hspace-last':
if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
continue
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
# print(f"{lora_name}")
lora = self.module(
lora_name, child_module, multiplier, rank, self.alpha
)
# print(name, child_name)
# print(child_module.weight.shape)
loras.append(lora)
names.append(lora_name)
# print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
return loras

def prepare_optimizer_params(self):
all_params = []

        if self.unet_loras:  # in practice this is the only case
            params = []
            for lora in self.unet_loras:
                params.extend(lora.parameters())
param_data = {"params": params}
all_params.append(param_data)

return all_params

def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
state_dict = self.state_dict()

if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v

# for key in list(state_dict.keys()):
# if not key.startswith("lora"):
        #         # drop keys that are not lora
# del state_dict[key]

if os.path.splitext(file)[1] == ".safetensors":
save_file(state_dict, file, metadata)
else:
            torch.save(state_dict, file)

    def set_lora_slider(self, scale):
self.lora_scale = scale

def __enter__(self):
for lora in self.unet_loras:
lora.multiplier = 1.0 * self.lora_scale

def __exit__(self, exc_type, exc_value, tb):
for lora in self.unet_loras:
lora.multiplier = 0
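
For context, a minimal sketch of how LoRANetwork could be wired into a training loop. The pretrained model id, rank, and slider scale below are assumptions for illustration only; the type hint in lora.py suggests a diffusers UNet2DConditionModel, while the actual Stable Cascade training script in this commit may wire things differently:

# Illustrative sketch only; the model id and hyper-parameters are assumptions.
import torch
from diffusers import UNet2DConditionModel
from lora import LoRANetwork

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet"  # placeholder model id
)
network = LoRANetwork(unet, rank=4, alpha=1.0, train_method="xattn")

# prepare_optimizer_params() returns a list of param groups, ready for an optimizer.
optimizer = torch.optim.AdamW(network.prepare_optimizer_params(), lr=1e-4)

# Slider behaviour: inside the `with` block the LoRA deltas are active at the
# chosen scale; on exit the multiplier is reset to 0 (LoRA effectively disabled).
network.set_lora_slider(scale=2.0)
with network:
    ...  # run the U-Net forward pass / loss / optimizer.step() here

network.save_weights("slider_lora.safetensors", dtype=torch.float16)
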