From 3bf7029b92598220d78ee97415908cfcec7358eb Mon Sep 17 00:00:00 2001 From: technillogue Date: Fri, 8 Nov 2024 14:13:40 -0500 Subject: [PATCH 1/5] remove some einops usage --- flux/math.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/flux/math.py b/flux/math.py index f30fa66..537a90c 100644 --- a/flux/math.py +++ b/flux/math.py @@ -1,5 +1,5 @@ import torch -from einops import rearrange +#from einops import rearrange from torch import Tensor from torch.nn.attention import SDPBackend, sdpa_kernel @@ -10,19 +10,28 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: # Only enable flash attention backend with sdpa_kernel(SDPBackend.FLASH_ATTENTION): x = torch.nn.functional.scaled_dot_product_attention(q, k, v) - x = rearrange(x, "B H L D -> B L (H D)") - + # x = rearrange(x, "B H L D -> B L (H D)") + x = x.transpose(1, 2).contiguous().view(x.size(0), x.size(2), -1) return x def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 - scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + # f64 is problematic + # https://github.com/pytorch/TensorRT/blob/v2.4.0/py/torch_tensorrt/dynamo/conversion/converter_utils.py#L380 + scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim + # scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim omega = 1.0 / (theta**scale) - out = torch.einsum("...n,d->...nd", pos, omega) - out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) - out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) - return out.float() + # out = torch.einsum("...n,d->...nd", pos, omega) + out = pos.unsqueeze(-1) * omega + # out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + # out = rearrange(out, "b n d (i 
j) -> b n d i j", i=2, j=2) + # Reshaping the tensor to (..., n, d, 2, 2) + out = out.view(*out.shape[:-1], 2, 2) + return out # .float() def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: From fd1546df201b3fbbfba2dfda479570b6691df242 Mon Sep 17 00:00:00 2001 From: technillogue Date: Tue, 17 Sep 2024 13:59:01 -0400 Subject: [PATCH 2/5] formatting, initial changes, types, inlining, and initial attempt --- flux/model.py | 4 ++-- flux/modules/layers.py | 3 ++- flux/sampling.py | 16 ++++++++++------ predict.py | 22 ++++++++++++++++++++++ 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/flux/model.py b/flux/model.py index 75a681f..00e5f03 100644 --- a/flux/model.py +++ b/flux/model.py @@ -81,7 +81,7 @@ def __init__(self, params: FluxParams): def forward( self, - img: Tensor, + img: Tensor, # (bs, dynamic, 64) img_ids: Tensor, txt: Tensor, txt_ids: Tensor, @@ -93,7 +93,7 @@ def forward( raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img - img = self.img_in(img) + img = self.img_in(img) # (bs, dynamic, hidden_size) vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: diff --git a/flux/modules/layers.py b/flux/modules/layers.py index 091ddf6..73753f7 100644 --- a/flux/modules/layers.py +++ b/flux/modules/layers.py @@ -25,7 +25,7 @@ def forward(self, ids: Tensor) -> Tensor: return emb.unsqueeze(1) -def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): +def timestep_embedding(t: Tensor, dim, max_period=10000): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. @@ -34,6 +34,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. 
""" + time_factor = torch.tensor(1000.0) t = time_factor * t half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( diff --git a/flux/sampling.py b/flux/sampling.py index c250e64..bfd60b8 100644 --- a/flux/sampling.py +++ b/flux/sampling.py @@ -30,7 +30,9 @@ def get_noise( ) -def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: +def prepare( + t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str] +) -> dict[str, Tensor]: bs, c, h, w = img.shape if bs == 1 and not isinstance(prompt, str): bs = len(prompt) @@ -104,7 +106,7 @@ def denoise_single_item( vec: Tensor, timesteps: list[float], guidance: float = 4.0, - compile_run: bool = False + compile_run: bool = False, ): img = img.unsqueeze(0) img_ids = img_ids.unsqueeze(0) @@ -127,8 +129,8 @@ def denoise_single_item( img_ids=img_ids, txt=txt, txt_ids=txt_ids, - y=vec, timesteps=t_vec, + y=vec, guidance=guidance_vec, ) @@ -136,6 +138,7 @@ def denoise_single_item( return img, model + def denoise( model: Flux, # model input @@ -147,7 +150,7 @@ def denoise( # sampling parameters timesteps: list[float], guidance: float = 4.0, - compile_run: bool = False + compile_run: bool = False, ): batch_size = img.shape[0] output_imgs = [] @@ -162,13 +165,14 @@ def denoise( vec[i], timesteps, guidance, - compile_run + compile_run, ) compile_run = False output_imgs.append(denoised_img) return torch.cat(output_imgs), model + def unpack(x: Tensor, height: int, width: int) -> Tensor: return rearrange( x, @@ -177,4 +181,4 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor: w=math.ceil(width / 16), ph=2, pw=2, - ) + ) diff --git a/predict.py b/predict.py index 4b68c0a..68bc4c9 100644 --- a/predict.py +++ b/predict.py @@ -1,4 +1,7 @@ +import torch_tensorrt import os +#os.environ["TORCH_LOGS"] = "+dynamic" +#os.environ["TORCH_COMPILE_DEBUG"] = "1" import time from typing import Any, Dict, Optional @@ 
-148,7 +151,26 @@ def base_setup( self.offload = "A40" in gpu_name device = "cuda" + self.ae = load_ae(self.flow_model_name, device="cpu" if self.offload else device) + inp = [torch.rand([1, 3, 1024, 1024], device="cuda")] + opt_ae = torch_tensorrt.compile(self.ae, inputs=inp, options={"truncate_long_and_double": True}) + torch_tensorrt.save(opt_ae, "autoencoder.engine", inputs=inp) + self.ae = opt_ae max_length = 256 if self.flow_model_name == "flux-schnell" else 512 + self.ae = load_ae(self.flow_model_name, device="cpu" if self.offload else device) + if os.path.exists("decoder.engine"): + t = time.time() + self.ae.decoder = torch.export.load("decoder.engine").module() + print("loading decoder took", time.time()-t) + else: + #inputs = [torch.randn([1, 3, 1024, 1024]) # enc/dec + t = time.time() + inputs = [torch.randn([1, 16, 128, 128], device="cuda")] # dec + dec = torch_tensorrt.compile(self.ae.decoder, inputs=inputs, options={"truncate_long_and_double": True}) + torch_tensorrt.save(dec, "decoder.engine", inputs=inputs) + print("compiling and saving decoder took", time.time()-t) + self.ae.decoder = dec + self.t5 = load_t5(device, max_length=max_length) self.clip = load_clip(device) self.flux = load_flow_model( From 9db707f839f5e9cb1135a50086a085d9e7d3103d Mon Sep 17 00:00:00 2001 From: technillogue Date: Thu, 10 Oct 2024 00:57:07 -0400 Subject: [PATCH 3/5] use nn.Sequential to remove python control flow from autoencoder up/downsampling --- flux/modules/autoencoder.py | 53 ++++++++++++------------------------- 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/flux/modules/autoencoder.py b/flux/modules/autoencoder.py index 75159f7..83965a7 100644 --- a/flux/modules/autoencoder.py +++ b/flux/modules/autoencoder.py @@ -127,24 +127,21 @@ def __init__( curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() + self.in_ch_mult: tuple[int] = in_ch_mult + down_layers = [] block_in = self.ch for 
i_level in range(self.num_resolutions): - block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) - block_in = block_out - down = nn.Module() - down.block = block - down.attn = attn + down_layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out # ? + # originally this provided for attn layers, but those are never actually created if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in) + down_layers.append(Downsample(block_in)) curr_res = curr_res // 2 - self.down.append(down) + self.down = nn.Sequential(*down_layers) # middle self.mid = nn.Module() @@ -158,18 +155,10 @@ def __init__( def forward(self, x: Tensor) -> Tensor: # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1]) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) + h = self.conv_in(x) + h = self.down(h) # middle - h = hs[-1] h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) @@ -214,21 +203,19 @@ def __init__( self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) # upsampling - self.up = nn.ModuleList() + up_layers = [] for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() + blocks = [] block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks + 1): - block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) block_in = block_out - up = nn.Module() - up.block = block - up.attn = attn if i_level != 0: - up.upsample = 
Upsample(block_in) + blocks.append(Upsample(block_in)) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + # ??? gross + up_layers = blocks + up_layers # prepend to get consistent order + self.up = nn.Sequential(*up_layers) # end self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) @@ -244,13 +231,7 @@ def forward(self, z: Tensor) -> Tensor: h = self.mid.block_2(h) # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) + h = self.up(h) # end h = self.norm_out(h) From 29aa8c6ac653501ab287b1e2e29e77b05cf57e28 Mon Sep 17 00:00:00 2001 From: technillogue Date: Mon, 14 Oct 2024 17:57:52 -0400 Subject: [PATCH 4/5] try to keep the same state_dict structure lazily create up_descending after state dict is already loaded, but only do it once --- flux/modules/autoencoder.py | 61 +++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/flux/modules/autoencoder.py b/flux/modules/autoencoder.py index 83965a7..dc9a08c 100644 --- a/flux/modules/autoencoder.py +++ b/flux/modules/autoencoder.py @@ -105,6 +105,16 @@ def forward(self, x: Tensor): x = self.conv(x) return x +class DownBlock(nn.Module): + def __init__(self, block: list, downsample: nn.Module) -> None: + super().__init__() + # we're doing this instead of a flat nn.Sequential to preserve the keys "block" "downsample" + self.block = nn.Sequential(*block) + self.downsample = downsample + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.downsample(self.block(x)) + class Encoder(nn.Module): def __init__( @@ -127,20 +137,25 @@ def __init__( curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult: tuple[int] = in_ch_mult + self.in_ch_mult = 
in_ch_mult down_layers = [] block_in = self.ch + # ideally, this would all append to a single flat nn.Sequential + # we cannot do this due to the existing state dict keys for i_level in range(self.num_resolutions): - attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] + block_layers = [] for _ in range(self.num_res_blocks): - down_layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) block_in = block_out # ? # originally this provided for attn layers, but those are never actually created if i_level != self.num_resolutions - 1: - down_layers.append(Downsample(block_in)) + downsample = Downsample(block_in) curr_res = curr_res // 2 + else: + downsample = nn.Identity() + down_layers.append(DownBlock(block_layers, downsample)) self.down = nn.Sequential(*down_layers) # middle @@ -168,6 +183,15 @@ def forward(self, x: Tensor) -> Tensor: h = self.conv_out(h) return h +class UpBlock(nn.Module): + def __init__(self, block: list, upsample: nn.Module) -> None: + super().__init__() + self.block = nn.Sequential(*block) + self.upsample = upsample + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.upsample(self.block(x)) + class Decoder(nn.Module): def __init__( @@ -203,24 +227,37 @@ def __init__( self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) # upsampling - up_layers = [] + up_blocks = [] + # 3, 2, 1, 0, descending order for i_level in reversed(range(self.num_resolutions)): - blocks = [] + level_blocks = [] block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks + 1): - blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + level_blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) block_in = block_out if i_level != 0: - blocks.append(Upsample(block_in)) + upsample = Upsample(block_in) curr_res = curr_res * 2 - # ??? 
gross - up_layers = blocks + up_layers # prepend to get consistent order - self.up = nn.Sequential(*up_layers) + else: + upsample = nn.Identity() + # 0, 1, 2, 3, ascending order + up_blocks.insert(0, UpBlock(level_blocks, upsample)) # prepend to get consistent order + self.up = nn.Sequential(*up_blocks) # end self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + # this is a hack to get something like property but only evaluate it once + # we're doing it like this so that up_descending isn't in the state_dict keys + # without adding anything conditional to the main flow + def __getattr__(self, name): + if name == "up_descending": + self.up_descending = nn.Sequential(*reversed(self.up)) + Decoder.__getattr__ = nn.Module.__getattr__ + return self.up_descending + return super().__getattr__(name) + def forward(self, z: Tensor) -> Tensor: # z to block_in h = self.conv_in(z) @@ -231,7 +268,7 @@ def forward(self, z: Tensor) -> Tensor: h = self.mid.block_2(h) # upsampling - h = self.up(h) + h = self.up_descending(h) # end h = self.norm_out(h) From c982583264050c32cf79c6b43a6b7a1cfa91f797 Mon Sep 17 00:00:00 2001 From: technillogue Date: Thu, 7 Nov 2024 03:44:57 +0000 Subject: [PATCH 5/5] conditionally download and use decoder engine --- predict.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/predict.py b/predict.py index 68bc4c9..8b374a2 100644 --- a/predict.py +++ b/predict.py @@ -47,6 +47,9 @@ "https://weights.replicate.delivery/default/falconai/nsfw-image-detection.tar" ) +DECODER_URL = "https://weights.replicate.delivery/default/official-models/flux/ae/decoder.engine" +DECODER_PATH = "model-cache/ae/decoder.engine" + # Suppress diffusers nsfw warnings logging.getLogger("diffusers").setLevel(logging.CRITICAL) logging.getLogger("transformers").setLevel(logging.CRITICAL) @@ -151,22 +154,28 @@ def 
base_setup( self.offload = "A40" in gpu_name device = "cuda" - self.ae = load_ae(self.flow_model_name, device="cpu" if self.offload else device) - inp = [torch.rand([1, 3, 1024, 1024], device="cuda")] - opt_ae = torch_tensorrt.compile(self.ae, inputs=inp, options={"truncate_long_and_double": True}) - torch_tensorrt.save(opt_ae, "autoencoder.engine", inputs=inp) - self.ae = opt_ae max_length = 256 if self.flow_model_name == "flux-schnell" else 512 + # we still need to load the encoder but it would be better to avoid loading the decoder twice self.ae = load_ae(self.flow_model_name, device="cpu" if self.offload else device) - if os.path.exists("decoder.engine"): + if not os.getenv("COMPILE_ENGINE"): + if not os.path.exists(DECODER_PATH): + download_weights(DECODER_URL, DECODER_PATH) t = time.time() - self.ae.decoder = torch.export.load("decoder.engine").module() - print("loading decoder took", time.time()-t) + self.ae.decoder = torch.export.load(DECODER_PATH).module() + print(f"loading decoder took {time.time() - t:.3f}s") else: #inputs = [torch.randn([1, 3, 1024, 1024]) # enc/dec t = time.time() inputs = [torch.randn([1, 16, 128, 128], device="cuda")] # dec - dec = torch_tensorrt.compile(self.ae.decoder, inputs=inputs, options={"truncate_long_and_double": True}) + self.ae.decoder.up_descending # access + self.orig_decoder = self.ae.decoder + base = {"truncate_long_and_double": True} + best = { + "num_avg_timing_iters": 2, + "use_fast_partitioner": False, + "optimization_level": 5, + } + dec = torch_tensorrt.compile(self.ae.decoder, inputs=inputs, options=base | best) torch_tensorrt.save(dec, "decoder.engine", inputs=inputs) print("compiling and saving decoder took", time.time()-t) self.ae.decoder = dec @@ -177,9 +186,6 @@ def base_setup( self.flow_model_name, device="cpu" if self.offload else device ) self.flux = self.flux.eval() - self.ae = load_ae( - self.flow_model_name, device="cpu" if self.offload else device - ) self.num_steps = 4 if self.flow_model_name == 
"flux-schnell" else 28 self.shift = self.flow_model_name != "flux-schnell"