Commit
Added a parameter to customize the text encoder dim and alpha separately; works only for Stable Cascade at the moment and requires testing.
Jeff Ding committed May 3, 2024
1 parent ea366ea commit 95d94b0
Showing 2 changed files with 40 additions and 2 deletions.
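
For example, a hypothetical invocation (most required training arguments omitted; the values are illustrative) that trains the text encoder LoRA at a lower rank than the U-Net:

python sd_scripts/stable_cascade_train_c_network.py \
    --network_dim 32 --network_alpha 16 \
    --text_encoder_dim 8 --text_encoder_alpha 4 \
    [other training args]
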
22 changes: 20 additions & 2 deletions sd_scripts/networks/lora.py
@@ -421,6 +421,8 @@ def create_network(
     multiplier: float,
     network_dim: Optional[int],
     network_alpha: Optional[float],
+    text_encoder_dim: Optional[int],
+    text_encoder_alpha: Optional[float],
     vae: AutoencoderKL,
     text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
     unet,
@@ -432,6 +434,11 @@ def create_network(
     if network_alpha is None:
         network_alpha = 1.0
 
+    if text_encoder_dim is None:
+        text_encoder_dim = network_dim
+    if text_encoder_alpha is None:
+        text_encoder_alpha = network_alpha
+
     # extract dim/alpha for conv2d, and block dim
     conv_dim = kwargs.get("conv_dim", None)
     conv_alpha = kwargs.get("conv_alpha", None)
@@ -482,6 +489,8 @@ def create_network(
         multiplier=multiplier,
         lora_dim=network_dim,
         alpha=network_alpha,
+        text_encoder_dim=text_encoder_dim,
+        text_encoder_alpha=text_encoder_alpha,
         dropout=neuron_dropout,
         rank_dropout=rank_dropout,
         module_dropout=module_dropout,
@@ -770,6 +779,8 @@ def __init__(
         multiplier: float = 1.0,
         lora_dim: int = 4,
         alpha: float = 1,
+        text_encoder_dim: int = 4,
+        text_encoder_alpha: float = 1,
         dropout: Optional[float] = None,
         rank_dropout: Optional[float] = None,
         module_dropout: Optional[float] = None,
@@ -802,6 +813,9 @@ def __init__(
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.text_encoder_dim = text_encoder_dim
+        self.text_encoder_alpha = text_encoder_alpha
+
 
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
@@ -879,8 +893,12 @@ def create_modules(
                 else:
                     # usually, target all modules / 通常、すべて対象とする
                     if is_linear or is_conv2d_1x1:
-                        dim = self.lora_dim
-                        alpha = self.alpha
+                        if is_unet:
+                            dim = self.lora_dim
+                            alpha = self.alpha
+                        else:
+                            dim = self.text_encoder_dim
+                            alpha = self.text_encoder_alpha
                     elif self.conv_lora_dim is not None:
                         dim = self.conv_lora_dim
                         alpha = self.conv_alpha
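
A minimal standalone sketch of the is_unet branch above (hypothetical values; assumes the usual LoRA scaling convention scale = alpha / dim used elsewhere in sd_scripts, and is not the actual module-walking code):

# Sketch only: mirrors the dim/alpha selection in create_modules.
def pick_dim_alpha(is_unet, lora_dim=4, alpha=1.0, text_encoder_dim=8, text_encoder_alpha=4.0):
    if is_unet:
        dim, a = lora_dim, alpha
    else:
        dim, a = text_encoder_dim, text_encoder_alpha
    return dim, a, a / dim  # scale applied to the LoRA delta

print(pick_dim_alpha(True))   # (4, 1.0, 0.25) -> U-Net modules
print(pick_dim_alpha(False))  # (8, 4.0, 0.5)  -> text encoder modules
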
20 changes: 20 additions & 0 deletions sd_scripts/stable_cascade_train_c_network.py
@@ -375,10 +375,18 @@ def train(self, args):
             # workaround for LyCORIS (;^ω^)
             net_kwargs["dropout"] = args.network_dropout
 
+        if args.text_encoder_dim is None:
+            args.text_encoder_dim = args.network_dim
+
+        if args.text_encoder_alpha is None:
+            args.text_encoder_alpha = args.network_alpha
+
         network = network_module.create_network(
             1.0,
             args.network_dim,
             args.network_alpha,
+            args.text_encoder_dim,
+            args.text_encoder_alpha,
             effnet,
             text_encoder,
             stage_c,
@@ -1136,6 +1144,18 @@ def setup_parser() -> argparse.ArgumentParser:
         default=1,
         help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
     )
+    parser.add_argument(
+        "--text_encoder_dim",
+        type=int,
+        default=None,
+        help="LoRA dimensions for the text encoder; defaults to network_dim if not set / テキストエンコーダ用LoRAの次元数(未指定の場合はnetwork_dimを使用)",
+    )
+    parser.add_argument(
+        "--text_encoder_alpha",
+        type=float,
+        default=None,
+        help="alpha for text encoder LoRA weight scaling; defaults to network_alpha if not set / テキストエンコーダ用LoRAのalpha値(未指定の場合はnetwork_alphaを使用)",
+    )
     parser.add_argument(
         "--network_dropout",
         type=float,
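
A quick self-contained sketch of the fallback wiring in train() above (hypothetical flag values; assumes both text-encoder defaults stay None so the fallback can fire):

import argparse

# Minimal reproduction of the relevant flags; the real parser defines many more.
parser = argparse.ArgumentParser()
parser.add_argument("--network_dim", type=int, default=None)
parser.add_argument("--network_alpha", type=float, default=1)
parser.add_argument("--text_encoder_dim", type=int, default=None)
parser.add_argument("--text_encoder_alpha", type=float, default=None)
args = parser.parse_args(["--network_dim", "32", "--network_alpha", "16"])

# same fallback as in train(): unset text-encoder values inherit the U-Net ones
if args.text_encoder_dim is None:
    args.text_encoder_dim = args.network_dim
if args.text_encoder_alpha is None:
    args.text_encoder_alpha = args.network_alpha

print(args.text_encoder_dim, args.text_encoder_alpha)  # 32 16.0
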
