Skip to content

Commit

Permalink
compute capability is actually correct here
Browse files Browse the repository at this point in the history
  • Loading branch information
daanelson committed Oct 1, 2024
1 parent 8b17584 commit 72a5086
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,6 @@ def base_setup(
# keeping all models resident on the GPU needs > 48 GB of VRAM, so offload on the A40
self.offload = "A40" in gpu_name

# fp8 only works on H100 and l40s atm.
self.disable_fp8 = disable_fp8 or ("H100" not in gpu_name and "L40S" not in gpu_name)

device = "cuda"
max_length = 256 if self.flow_model_name == "flux-schnell" else 512
self.t5 = load_t5(device, max_length=max_length)
Expand All @@ -169,6 +166,10 @@ def base_setup(
shared_models = LoadedModels(
flow=None, ae=self.ae, clip=self.clip, t5=self.t5, config=None
)

# fp8 only works with CUDA compute capability >= 8.9
self.disable_fp8 = disable_fp8 or torch.cuda.get_device_capability() < (8, 9)

if not self.disable_fp8:
self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
f"fp8/configs/config-1-{flow_model_name}-h100.json",
Expand Down

0 comments on commit 72a5086

Please sign in to comment.