Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model Support] FLUX.1-dev #28

Merged
merged 17 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[flake8]
max-line-length = 120
extend-ignore = E203
filename = *.py

[isort]
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ __pycache__/
# Distribution / packaging
.Python
build/
.build/
develop-eggs/
dist/
downloads/
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ repos:
- id: black
name: black
language: python
args: ["--config", ".flake8"]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
Expand Down
4 changes: 3 additions & 1 deletion python/src/diffusionkit/mlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
"argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
"argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev",
}

# Length of the T5 token buffer per model version. encode_text() allocates a
# zero tensor of shape (1, T5_MAX_LENGTH[model_version]) and copies the prompt
# tokens into its leading positions.
T5_MAX_LENGTH = {
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
"argmaxinc/mlx-FLUX.1-schnell": 256,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256,
"argmaxinc/mlx-FLUX.1-dev": 512,
}


Expand Down Expand Up @@ -653,7 +655,7 @@ def encode_text(
text,
(negative_text if cfg_weight > 1 else None),
)
padded_tokens_t5 = mx.zeros((1, 256)).astype(tokens_t5.dtype)
padded_tokens_t5 = mx.zeros((1, T5_MAX_LENGTH[self.model_version])).astype(tokens_t5.dtype)
padded_tokens_t5[:, : tokens_t5.shape[1]] = tokens_t5[
[0], :
] # Ignore negative text
Expand Down
18 changes: 18 additions & 0 deletions python/src/diffusionkit/mlx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def hidden_size(self) -> int:

low_memory_mode: bool = True

guidance_embed: bool = False


SD3_8b = MMDiTConfig(depth_multimodal=38, num_heads=3, upcast_multimodal_blocks=[35])

Expand All @@ -90,6 +92,22 @@ def hidden_size(self) -> int:
dtype=mx.bfloat16,
)

# MMDiT configuration used for the FLUX.1-dev checkpoint. It sets
# guidance_embed=True, which makes MMDiT.__init__ build an MLPEmbedder
# for `guidance_in` instead of an identity layer.
FLUX_DEV = MMDiTConfig(
num_heads=24,
depth_multimodal=19,
depth_unified=38,
parallel_mlp_for_unified_blocks=True,
hidden_size_override=3072,
patchify_via_reshape=True,
pos_embed_type=PositionalEncoding.PreSDPARope,
rope_axes_dim=(16, 56, 56),
pooled_text_embed_dim=768, # CLIP-L/14 only
use_qk_norm=True,
float16_dtype=mx.bfloat16,
guidance_embed=True,
dtype=mx.bfloat16,
)


@dataclass
class AutoencoderConfig:
Expand Down
24 changes: 23 additions & 1 deletion python/src/diffusionkit/mlx/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def __init__(self, config: MMDiTConfig):
super().__init__()
self.config = config

if config.guidance_embed:
self.guidance_in = MLPEmbedder(
in_dim=config.frequency_embed_dim, hidden_dim=config.hidden_size
)
else:
self.guidance_in = nn.Identity()

# Input adapters and embeddings
self.x_embedder = LatentImageAdapter(config)

Expand Down Expand Up @@ -209,6 +216,9 @@ def __call__(
else:
positional_encodings = None

if self.config.guidance_embed:
timestep = self.guidance_in(self.t_embedder(timestep))

# MultiModalTransformer layers
if self.config.depth_multimodal > 0:
for bidx, block in enumerate(self.multimodal_transformer_blocks):
Expand Down Expand Up @@ -236,7 +246,6 @@ def __call__(
:, token_level_text_embeddings.shape[1] :, ...
]

# Final layer
latent_image_embeddings = self.final_layer(
latent_image_embeddings,
timestep,
Expand Down Expand Up @@ -933,6 +942,19 @@ def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:
)


class MLPEmbedder(nn.Module):
    """Small embedding MLP: ``Linear -> SiLU -> Linear``.

    Args:
        in_dim: Size of the last dimension of the input.
        hidden_dim: Output size of both linear layers.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        # Keep the layers inside a single ``nn.Sequential`` stored as ``mlp`` so
        # the parameter names of this module stay unchanged.
        stages = [
            nn.Linear(in_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.mlp = nn.Sequential(*stages)

    def __call__(self, x):
        """Run ``x`` through the MLP and return the result."""
        embedded = self.mlp(x)
        return embedded


def affine_transform(
x: mx.array,
shift: mx.array,
Expand Down
10 changes: 9 additions & 1 deletion python/src/diffusionkit/mlx/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
"vae": "ae.safetensors",
},
"argmaxinc/mlx-FLUX.1-dev": {
"argmaxinc/mlx-FLUX.1-dev": "flux1-dev.safetensors",
"vae": "ae.safetensors",
},
}
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
_MODELS = {
Expand Down Expand Up @@ -75,6 +79,10 @@
"vae_encoder": "encoder.",
"vae_decoder": "decoder.",
},
"argmaxinc/mlx-FLUX.1-dev": {
"vae_encoder": "encoder.",
"vae_decoder": "decoder.",
},
}

_FLOAT16 = mx.bfloat16
Expand Down Expand Up @@ -704,7 +712,7 @@ def load_flux(
hf_hub_download(key, "config.json")
weights = mx.load(flux_weights_ckpt)

if model_key == "argmaxinc/mlx-FLUX.1-schnell":
if model_key in ["argmaxinc/mlx-FLUX.1-schnell", "argmaxinc/mlx-FLUX.1-dev"]:
weights = flux_state_dict_adjustments(
weights,
prefix="",
Expand Down
5 changes: 4 additions & 1 deletion python/src/diffusionkit/mlx/scripts/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@
"sd3-8b-unreleased": 1024,
"argmaxinc/mlx-FLUX.1-schnell": 512,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
"argmaxinc/mlx-FLUX.1-dev": 512,
}
# Per-model-version width values, keyed by model identifier.
# NOTE(review): presumably the default output width in pixels used by the CLI;
# the consumer is not visible here, so confirm.
WIDTH = {
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
"sd3-8b-unreleased": 1024,
"argmaxinc/mlx-FLUX.1-schnell": 512,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
"argmaxinc/mlx-FLUX.1-dev": 512,
}
# Per-model-version shift values, keyed by model identifier.
# NOTE(review): presumably the default `shift` handed to the pipeline/sampler;
# confirm against cli().
SHIFT = {
"argmaxinc/mlx-stable-diffusion-3-medium": 3.0,
"sd3-8b-unreleased": 3.0,
"argmaxinc/mlx-FLUX.1-schnell": 1.0,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0,
"argmaxinc/mlx-FLUX.1-dev": 1.0,
}


Expand Down Expand Up @@ -111,7 +114,7 @@ def cli():
args.a16 = True

if "FLUX" in args.model_version and args.cfg > 0.0:
logger.warning("Disabling CFG for FLUX.1-schnell model.")
logger.warning(f"Disabling CFG for {args.model_version} model.")
args.cfg = 0.0

if args.benchmark_mode:
Expand Down
65 changes: 65 additions & 0 deletions python/src/diffusionkit/mlx/test-conversion-mlx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest
import mlx.core as mx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please follow the `unittest.TestCase` usage and make this a simple unit test like this? Also, there is no need to upload to the Hub from within the test 👍

import os
from pathlib import Path
import sys

current_dir = Path(__file__).resolve().parent
parent_dir = current_dir.parent
sys.path.append(str(parent_dir))

try:
from .config import FLUX_DEV, FLUX_SCHNELL, MMDiTConfig
from .mmdit import MMDiT
from .model_io import flux_state_dict_adjustments
except ImportError:
from diffusionkit.mlx.config import FLUX_DEV, FLUX_SCHNELL, MMDiTConfig
from diffusionkit.mlx.mmdit import MMDiT
from diffusionkit.mlx.model_io import flux_state_dict_adjustments

class TestFluxConversion(unittest.TestCase):
    """Check that converted FLUX.1 checkpoint keys and shapes match the MLX ``MMDiT`` model.

    The test needs a locally available ``.safetensors`` checkpoint; when the
    file is absent the test is skipped rather than failed, because a missing
    multi-gigabyte download is an environment precondition, not a regression.
    """

    # Machine-specific Hugging Face cache root exported as HF_HOME for the
    # duration of this test class.
    custom_hf_home = "/Volumes/USB/huggingface/hub"

    @classmethod
    def setUpClass(cls):
        """Point HF_HOME at ``custom_hf_home``, remembering the previous value."""
        cls._previous_hf_home = os.environ.get("HF_HOME")
        os.environ["HF_HOME"] = cls.custom_hf_home

    @classmethod
    def tearDownClass(cls):
        """Restore HF_HOME so the override does not leak into other tests."""
        if cls._previous_hf_home is None:
            os.environ.pop("HF_HOME", None)
        else:
            os.environ["HF_HOME"] = cls._previous_hf_home

    def load_flux_weights(self, model_key="flux-dev"):
        """Load a local FLUX checkpoint and return ``(weights, config)``.

        Args:
            model_key: ``"flux-dev"`` selects FLUX.1-dev; any other value
                selects FLUX.1-schnell.

        Returns:
            The result of ``mx.load`` on the checkpoint file and the matching
            ``MMDiTConfig``.

        The test is skipped if the checkpoint file does not exist.
        """
        if model_key == "flux-dev":
            config = FLUX_DEV
            repo_id = "black-forest-labs/FLUX.1-dev"
            file_name = "flux1-dev.safetensors"
        else:
            config = FLUX_SCHNELL
            repo_id = "black-forest-labs/FLUX.1-schnell"
            file_name = "flux1-schnell.safetensors"

        hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
        # NOTE(review): with HF_HOME already ending in ".../hub" this resolves to
        # ".../hub/hub/<repo>/<file>" -- confirm this matches the intended layout.
        local_file = os.path.join(hf_home, "hub", repo_id.split("/")[-1], file_name)

        if not os.path.exists(local_file):
            self.skipTest(f"Test file not found: {local_file}. Please download it before running the test.")

        weights = mx.load(local_file)
        return weights, config

    def test_flux_conversion(self):
        """Adjusted checkpoint keys and shapes must match the model's array parameters."""
        # tree_flatten is documented under mlx.utils, not mlx.core.
        from mlx.utils import tree_flatten

        weights, config = self.load_flux_weights("flux-dev")

        model = MMDiT(config)
        mlx_dict = {name: value for name, value in tree_flatten(model) if isinstance(value, mx.array)}

        adjusted_weights = flux_state_dict_adjustments(
            weights, prefix="", hidden_size=config.hidden_size, mlp_ratio=config.mlp_ratio
        )

        weights_set = set(adjusted_weights.keys())
        mlx_dict_set = set(mlx_dict.keys())

        # Compare sorted key lists so a failure names the offending keys.
        self.assertEqual(sorted(weights_set - mlx_dict_set), [], "There are keys in weights but not in model")
        self.assertEqual(sorted(mlx_dict_set - weights_set), [], "There are keys in model but not in weights")

        mismatches = {
            key: (adjusted_weights[key].shape, mlx_dict[key].shape)
            for key in sorted(weights_set & mlx_dict_set)
            if adjusted_weights[key].shape != mlx_dict[key].shape
        }

        self.assertEqual(
            mismatches, {}, f"Found {len(mismatches)} shape mismatches between weights and model"
        )

if __name__ == "__main__":
unittest.main()
65 changes: 65 additions & 0 deletions python/test-gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
from pathlib import Path
from diffusionkit.mlx import FluxPipeline
from huggingface_hub import HfFolder, HfApi
from PIL import Image

# Candidate Hugging Face cache roots: one on the external "/Volumes/USB"
# volume and the default per-user cache under the home directory.
# set_hf_cache() exports one of them as HF_HOME.
usb_cache_path = "/Volumes/USB/huggingface/cache"
local_cache_path = os.path.expanduser("~/.cache/huggingface")


# Function to set and verify cache directory
def set_hf_cache():
    """Export HF_HOME, preferring the USB cache when "/Volumes/USB" exists.

    When the USB volume is present, HF_HOME is set to ``usb_cache_path`` and
    that directory is created if needed; otherwise HF_HOME is set to
    ``local_cache_path``. The stored Hugging Face token, if any, is then
    re-saved.

    NOTE(review): huggingface_hub appears to resolve HF_HOME when it is first
    imported, and this module imports it before calling this function --
    confirm that the override actually takes effect.
    """
    if os.path.exists("/Volumes/USB"):
        os.environ["HF_HOME"] = usb_cache_path
        Path(usb_cache_path).mkdir(parents=True, exist_ok=True)
        print(f"Using USB cache: {usb_cache_path}")
    else:
        os.environ["HF_HOME"] = local_cache_path
        print(f"USB not found. Using local cache: {local_cache_path}")

    print(f"HF_HOME is set to: {os.environ['HF_HOME']}")

    # get_token() returns None when no token is stored; save_token() expects a
    # string, so only re-save when there is something to write.
    token = HfFolder.get_token()
    if token is not None:
        HfFolder.save_token(token)


# Set cache before initializing the pipeline
set_hf_cache()

# Initialize the pipeline.
# The model version must be one of the identifiers registered in
# diffusionkit.mlx (e.g. the keys of T5_MAX_LENGTH); a bare "FLUX.1-dev" is not
# among them and would fail the per-model dictionary lookups.
pipeline = FluxPipeline(
    shift=1.0,
    model_version="argmaxinc/mlx-FLUX.1-dev",
    low_memory_mode=True,
    a16=True,
    w16=True,
)

# Load LoRA weights
# pipeline.load_lora_weights("XLabs-AI/flux-RealismLora")

# Define image generation parameters
HEIGHT = 512
WIDTH = 512
NUM_STEPS = 10  # 4 for FLUX.1-schnell, 50 for SD3
CFG_WEIGHT = 0.  # 0. for FLUX.1-schnell, 5. for SD3
# LORA_SCALE = 0.8  # LoRA strength

# Define the prompt
prompt = "A photo realistic cat holding a sign that says hello world in the style of a snapchat from 2015"

# Generate the image. latent_size is the pixel size divided by 8.
image, _ = pipeline.generate_image(
    prompt,
    cfg_weight=CFG_WEIGHT,
    num_steps=NUM_STEPS,
    latent_size=(HEIGHT // 8, WIDTH // 8),
    # lora_scale=LORA_SCALE,
)

# Save the generated image
output_format = "png"
output_quality = 100
image.save(f"flux_image.{output_format}", format=output_format, quality=output_quality)

print(f"Image generation complete. Saved image in {output_format} format.")
Loading