TRT decoder #45

Open · wants to merge 5 commits into main · changes from all commits
25 changes: 17 additions & 8 deletions flux/math.py
@@ -1,5 +1,5 @@
import torch
from einops import rearrange
#from einops import rearrange
from torch import Tensor
from torch.nn.attention import SDPBackend, sdpa_kernel

@@ -10,19 +10,28 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
# Only enable flash attention backend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")

# x = rearrange(x, "B H L D -> B L (H D)")
x = x.transpose(1, 2).contiguous().view(x.size(0), x.size(2), -1)
return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
# f64 is problematic
# https://github.com/pytorch/TensorRT/blob/v2.4.0/py/torch_tensorrt/dynamo/conversion/converter_utils.py#L380
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
# scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
# out = torch.einsum("...n,d->...nd", pos, omega)
out = pos.unsqueeze(-1) * omega
# out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
# out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
# Reshaping the tensor to (..., n, d, 2, 2)
out = out.view(*out.shape[:-1], 2, 2)
return out # .float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
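A quick equivalence check for the two einops-free rewrites above; a minimal sketch with assumed illustrative shapes, not part of the diff:

import torch
from einops import rearrange

x = torch.randn(2, 8, 16, 64)  # (B, H, L, D); shapes assumed for illustration
a = rearrange(x, "B H L D -> B L (H D)")
b = x.transpose(1, 2).contiguous().view(x.size(0), x.size(2), -1)
assert torch.equal(a, b)  # transpose + view reproduces the einops rearrange

pos = torch.randn(2, 16)
omega = torch.randn(32)
ein = torch.einsum("...n,d->...nd", pos, omega)
bcast = pos.unsqueeze(-1) * omega  # broadcasting replaces the einsum
assert torch.allclose(ein, bcast)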
4 changes: 2 additions & 2 deletions flux/model.py
@@ -81,7 +81,7 @@ def __init__(self, params: FluxParams):

def forward(
self,
img: Tensor,
img: Tensor, # (bs, dynamic, 64)
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
@@ -93,7 +93,7 @@ def forward(
raise ValueError("Input img and txt tensors must have 3 dimensions.")

# running on sequences img
img = self.img_in(img)
img = self.img_in(img) # (bs, dynamic, hidden_size)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
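The "(bs, dynamic, 64)" annotation above marks the variable sequence length. If the flow model were also compiled with TensorRT, that dimension would need an explicit shape range; a hedged sketch of what that could look like (the bounds are assumptions, not from this PR):

import torch
import torch_tensorrt

# hypothetical dynamic-shape spec for the img input; min/opt/max are illustrative
img_input = torch_tensorrt.Input(
    min_shape=(1, 256, 64),
    opt_shape=(1, 4096, 64),
    max_shape=(1, 8192, 64),
    dtype=torch.float32,
)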
90 changes: 54 additions & 36 deletions flux/modules/autoencoder.py
@@ -105,6 +105,16 @@ def forward(self, x: Tensor):
x = self.conv(x)
return x

class DownBlock(nn.Module):
def __init__(self, block: list, downsample: nn.Module) -> None:
super().__init__()
# we do this instead of a flat nn.Sequential to preserve the state-dict keys "block" and "downsample"
self.block = nn.Sequential(*block)
self.downsample = downsample

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.downsample(self.block(x))
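
A minimal sketch (not in the diff) of why the wrapper keeps existing checkpoints loadable; the Conv2d stand-in is an assumption:

import torch.nn as nn

# attribute names become state-dict prefixes, so DownBlock reproduces the
# "block.*" and "downsample.*" keys the original nn.Module attributes produced
blk = DownBlock([nn.Conv2d(4, 4, 3, padding=1)], nn.Identity())
print(list(blk.state_dict().keys()))  # ['block.0.weight', 'block.0.bias']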


class Encoder(nn.Module):
def __init__(
@@ -128,23 +138,25 @@ def __init__(
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
down_layers = []
block_in = self.ch
# ideally, this would all append to a single flat nn.Sequential
# we cannot do this due to the existing state dict keys
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
block_layers = []
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
block_layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out  # this block's output channels become the next block's input channels
# originally this provided for attn layers, but those are never actually created
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
else:
downsample = nn.Identity()
down_layers.append(DownBlock(block_layers, downsample))
self.down = nn.Sequential(*down_layers)

# middle
self.mid = nn.Module()
@@ -158,18 +170,10 @@ def __init__(

def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
h = self.conv_in(h)
h = self.down(h)

# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
@@ -179,6 +183,15 @@ def forward(self, x: Tensor) -> Tensor:
h = self.conv_out(h)
return h

class UpBlock(nn.Module):
def __init__(self, block: list, upsample: nn.Module) -> None:
super().__init__()
self.block = nn.Sequential(*block)
self.upsample = upsample

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.upsample(self.block(x))


class Decoder(nn.Module):
def __init__(
@@ -214,26 +227,37 @@ def __init__(
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

# upsampling
self.up = nn.ModuleList()
up_blocks = []
# 3, 2, 1, 0, descending order
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
level_blocks = []
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
level_blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
else:
upsample = nn.Identity()
# 0, 1, 2, 3, ascending order
up_blocks.insert(0, UpBlock(level_blocks, upsample)) # prepend to get consistent order
self.up = nn.Sequential(*up_blocks)

# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

# this is a hack to get something like @property that is only evaluated once
# we do it this way so that up_descending isn't in the state_dict keys,
# without adding anything conditional to the main forward flow
def __getattr__(self, name):
if name == "up_descending":
self.up_descending = nn.Sequential(*reversed(self.up))
Decoder.__getattr__ = nn.Module.__getattr__
return self.up_descending
return super().__getattr__(name)
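
The same lazy trick in isolation; a standalone sketch, not the PR's code. For a plain class, __getattr__ only fires when normal lookup fails, so caching the value on the instance is enough; nn.Module needs the class-level patch above because its __setattr__ files submodules away from the instance __dict__.

class Lazy:
    def _build(self):
        print("built once")
        return [3, 2, 1]

    def __getattr__(self, name):
        if name == "items":
            self.items = self._build()  # cache on the instance
            return self.items
        raise AttributeError(name)

lazy = Lazy()
lazy.items  # prints "built once"
lazy.items  # cached; __getattr__ is not called again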

def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
@@ -244,13 +268,7 @@ def forward(self, z: Tensor) -> Tensor:
h = self.mid.block_2(h)

# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
h = self.up_descending(h)

# end
h = self.norm_out(h)
3 changes: 2 additions & 1 deletion flux/modules/layers.py
@@ -25,7 +25,7 @@ def forward(self, ids: Tensor) -> Tensor:
return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
def timestep_embedding(t: Tensor, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
@@ -34,6 +34,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
time_factor = torch.tensor(1000.0)
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
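A hedged reading of the time_factor change (the rationale is inferred, not stated in the PR): a bare Python float is captured as a 64-bit scalar, which runs into the same f64 limitation noted in flux/math.py, while torch.tensor(1000.0) defaults to float32:

import torch

t = torch.tensor([0.25, 0.5])
time_factor = torch.tensor(1000.0)
print(time_factor.dtype)        # torch.float32
print((time_factor * t).dtype)  # torch.float32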
16 changes: 10 additions & 6 deletions flux/sampling.py
@@ -30,7 +30,9 @@ def get_noise(
)


def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
def prepare(
t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]
) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
@@ -104,7 +106,7 @@ def denoise_single_item(
vec: Tensor,
timesteps: list[float],
guidance: float = 4.0,
compile_run: bool = False
compile_run: bool = False,
):
img = img.unsqueeze(0)
img_ids = img_ids.unsqueeze(0)
@@ -127,15 +129,16 @@ def denoise_single_item(
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
y=vec,
guidance=guidance_vec,
)

img = img + (t_prev - t_curr) * pred.squeeze(0)

return img, model


def denoise(
model: Flux,
# model input
@@ -147,7 +150,7 @@ def denoise(
# sampling parameters
timesteps: list[float],
guidance: float = 4.0,
compile_run: bool = False
compile_run: bool = False,
):
batch_size = img.shape[0]
output_imgs = []
@@ -162,13 +165,14 @@ def denoise(
vec[i],
timesteps,
guidance,
compile_run
compile_run,
)
compile_run = False
output_imgs.append(denoised_img)

return torch.cat(output_imgs), model


def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
@@ -177,4 +181,4 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor:
w=math.ceil(width / 16),
ph=2,
pw=2,
)
)
34 changes: 31 additions & 3 deletions predict.py
@@ -1,4 +1,7 @@
import torch_tensorrt
import os
#os.environ["TORCH_LOGS"] = "+dynamic"
#os.environ["TORCH_COMPILE_DEBUG"] = "1"
import time
from typing import Any, Dict, Optional

@@ -44,6 +47,9 @@
"https://weights.replicate.delivery/default/falconai/nsfw-image-detection.tar"
)

DECODER_URL = "https://weights.replicate.delivery/default/official-models/flux/ae/decoder.engine"
DECODER_PATH = "model-cache/ae/decoder.engine"

# Suppress diffusers nsfw warnings
logging.getLogger("diffusers").setLevel(logging.CRITICAL)
logging.getLogger("transformers").setLevel(logging.CRITICAL)
@@ -149,15 +155,37 @@ def base_setup(

device = "cuda"
max_length = 256 if self.flow_model_name == "flux-schnell" else 512
# we still need to load the encoder but it would be better to avoid loading the decoder twice
self.ae = load_ae(self.flow_model_name, device="cpu" if self.offload else device)
if not os.getenv("COMPILE_ENGINE"):
if not os.path.exists(DECODER_PATH):
download_weights(DECODER_URL, DECODER_PATH)
t = time.time()
self.ae.decoder = torch.export.load(DECODER_PATH).module()
print(f"loading decoder took {time.time() - t:.3f}s")
else:
# inputs = [torch.randn([1, 3, 1024, 1024])]  # enc/dec
t = time.time()
inputs = [torch.randn([1, 16, 128, 128], device="cuda")] # dec
self.ae.decoder.up_descending  # access once so the lazy attribute exists before compilation
self.orig_decoder = self.ae.decoder
base = {"truncate_long_and_double": True}
best = {
"num_avg_timing_iters": 2,
"use_fast_partitioner": False,
"optimization_level": 5,
}
dec = torch_tensorrt.compile(self.ae.decoder, inputs=inputs, options=base | best)
torch_tensorrt.save(dec, "decoder.engine", inputs=inputs)
print("compiling and saving decoder took", time.time()-t)
self.ae.decoder = dec

self.t5 = load_t5(device, max_length=max_length)
self.clip = load_clip(device)
self.flux = load_flow_model(
self.flow_model_name, device="cpu" if self.offload else device
)
self.flux = self.flux.eval()
self.ae = load_ae(
self.flow_model_name, device="cpu" if self.offload else device
)

self.num_steps = 4 if self.flow_model_name == "flux-schnell" else 28
self.shift = self.flow_model_name != "flux-schnell"
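
A minimal timing sketch (not in the PR) for comparing the compiled decoder against self.orig_decoder, assuming the same (1, 16, 128, 128) latent shape used for compilation:

import time
import torch

def bench(decoder, runs: int = 10) -> float:
    z = torch.randn(1, 16, 128, 128, device="cuda")
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        for _ in range(runs):
            decoder(z)
    torch.cuda.synchronize()
    return (time.time() - start) / runs  # average seconds per decode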