Make fp8 work on older GPUs (#34)
* fp8: fall back to float32 matmul on cuda capability < 8.9

This re-enables the use of fp8 on older GPUs, which can be useful for saving VRAM (see the sketch below).

* Don't compile fp8 when offloaded; it's going to be slow anyway
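
A minimal sketch of the capability gate described above (the helper name and structure are illustrative, not from the repo):

import torch

def supports_fp8_matmul(device: torch.device) -> bool:
    """torch._scaled_mm requires CUDA compute capability >= 8.9 (Ada/Hopper)."""
    if device.type != "cuda":
        return False
    return torch.cuda.get_device_capability(device) >= (8, 9)

# Older GPUs still keep the weights stored in fp8 to save VRAM; they just
# upcast to float32 for the matmul instead of calling torch._scaled_mm.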
yorickvP authored Nov 4, 2024
1 parent d052e21 commit a9a42fb
Showing 2 changed files with 58 additions and 13 deletions.
42 changes: 32 additions & 10 deletions fp8/float8_quantize.py
@@ -275,16 +275,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
prev_dims = x.shape[:-1]
x = x.view(-1, self.in_features)

# float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
out = torch._scaled_mm( # noqa
x,
self.float8_data.T,
scale_a=self.input_scale_reciprocal,
scale_b=self.scale_reciprocal,
bias=self.bias,
out_dtype=self.weight.dtype,
use_fast_accum=True,
)
device = x.device
if x.device.type != 'cpu' and torch.cuda.get_device_capability(x.device) >= (8, 9):
# float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices!
out = torch._scaled_mm( # noqa
x,
self.float8_data.T,
scale_a=self.input_scale_reciprocal,
scale_b=self.scale_reciprocal,
bias=self.bias,
out_dtype=self.weight.dtype,
use_fast_accum=True,
)
else:
# Plain matrix multiplication for non-ADA devices
# Assuming x is in float8 and self.float8_data is in float8 as well
# Convert to float32, perform the multiplication, and then apply scaling and bias if necessary

# Convert float8 to float32 for the multiplication
x_float32 = x.to(torch.float32)
float8_data_float32 = self.float8_data.T.to(torch.float32)

# Regular matrix multiplication
out = torch.matmul(x_float32, float8_data_float32)

# Scale the output accordingly
out = out * (self.input_scale_reciprocal * self.scale_reciprocal)

# Add bias if it exists
if self.bias is not None:
out += self.bias
out = out.to(self.weight.dtype)

if IS_TORCH_2_4:
out = out[0]
return out.view(*prev_dims, self.out_features)
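
For reference, a standalone sanity check of the dequantize-then-matmul fallback above (not from the repo; x_scale and w_scale are illustrative per-tensor dequantization scales):

import torch

torch.manual_seed(0)
x_fp32 = torch.randn(4, 8)
w_fp32 = torch.randn(16, 8)

# Quantize with per-tensor scales chosen so the largest |value| lands near
# fp8 e4m3's maximum (~448); dequantization multiplies by the same scale.
x_scale = x_fp32.abs().max() / 448.0
w_scale = w_fp32.abs().max() / 448.0
x_fp8 = (x_fp32 / x_scale).to(torch.float8_e4m3fn)
w_fp8 = (w_fp32 / w_scale).to(torch.float8_e4m3fn)

# The fallback path: upcast to float32, do a plain matmul, then apply the scales.
out = torch.matmul(x_fp8.to(torch.float32), w_fp8.to(torch.float32).T) * (x_scale * w_scale)

# The difference from the full-precision product is bounded by fp8 rounding error.
print((out - x_fp32 @ w_fp32.T).abs().max())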
29 changes: 26 additions & 3 deletions predict.py
@@ -165,7 +165,11 @@ def base_setup(
self.falcon_processor = ViTImageProcessor.from_pretrained(FALCON_MODEL_NAME)

# need > 48 GB of ram to store all models in VRAM
self.offload = "A40" in gpu_name
total_mem = torch.cuda.get_device_properties(0).total_memory
self.offload = total_mem < 48 * 1024**3
if self.offload:
print("GPU memory is:", total_mem / 1024**3, ", offloading models")
compile_fp8 = False

device = "cuda"
max_length = 256 if self.flow_model_name == "flux-schnell" else 512
@@ -187,13 +191,32 @@ def base_setup(
flow=None, ae=self.ae, clip=self.clip, t5=self.t5, config=None
)

# fp8 only works w/compute capability >= 8.9
self.disable_fp8 = disable_fp8 or torch.cuda.get_device_capability() < (8, 9)
self.disable_fp8 = disable_fp8

if not self.disable_fp8:
if compile_fp8:
extra_args = {
"compile_whole_model": True,
"compile_extras": True,
"compile_blocks": True,
}
else:
extra_args = {
"compile_whole_model": False,
"compile_extras": False,
"compile_blocks": False,
}

if self.offload:
extra_args |= {
"offload_text_encoder": True,
"offload_vae": True,
"offload_flow": True,
}
self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
f"fp8/configs/config-1-{flow_model_name}-h100.json",
shared_models=shared_models,
**extra_args,
)

if compile_fp8:
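
For reference, an illustrative reconstruction of the offload and compile-flag logic introduced in this file (a sketch, not the actual base_setup code; the dict merge with |= requires Python 3.9+):

import torch

def build_fp8_loader_args(compile_fp8: bool) -> dict:
    # Offloading kicks in when the GPU has under 48 GiB of memory,
    # since all models no longer fit in VRAM at once.
    total_mem = torch.cuda.get_device_properties(0).total_memory
    offload = total_mem < 48 * 1024**3
    if offload:
        compile_fp8 = False  # compiling an offloaded model is going to be slow anyway

    extra_args = {
        "compile_whole_model": compile_fp8,
        "compile_extras": compile_fp8,
        "compile_blocks": compile_fp8,
    }
    if offload:
        extra_args |= {
            "offload_text_encoder": True,
            "offload_vae": True,
            "offload_flow": True,
        }
    return extra_args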
