Skip to content

Commit

Permalink
Disable fp8 (#29)
Browse files Browse the repository at this point in the history
* Don't run fp8 on GPUs that can't run fp8 (CUDA compute capability below 8.9); fall back to the bf16 model instead
  • Loading branch information
daanelson authored Oct 1, 2024
1 parent 1fe5797 commit c720e2d
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def base_setup(
flow_model_name: str,
compile_fp8: bool = False,
compile_bf16: bool = False,
disable_fp8: bool = False,
) -> None:
self.flow_model_name = flow_model_name
print(f"Booting model {self.flow_model_name}")
Expand Down Expand Up @@ -166,13 +167,17 @@ def base_setup(
flow=None, ae=self.ae, clip=self.clip, t5=self.t5, config=None
)

self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
f"fp8/configs/config-1-{flow_model_name}-h100.json",
shared_models=shared_models,
)
# fp8 only works with compute capability >= 8.9; older GPUs fall back to bf16
self.disable_fp8 = disable_fp8 or torch.cuda.get_device_capability() < (8, 9)

if not self.disable_fp8:
self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
f"fp8/configs/config-1-{flow_model_name}-h100.json",
shared_models=shared_models,
)

if compile_fp8:
self.compile_fp8()
if compile_fp8:
self.compile_fp8()

if compile_bf16:
self.compile_bf16()
Expand Down Expand Up @@ -480,14 +485,16 @@ def predict(
) -> List[Path]:
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)

if go_fast:
if go_fast and not self.disable_fp8:
imgs, np_imgs = self.fp8_predict(
prompt,
num_outputs,
num_inference_steps=num_inference_steps,
**hws_kwargs,
)
else:
if self.disable_fp8:
print("running bf16 model, fp8 disabled")
imgs, np_imgs = self.base_predict(
prompt,
num_outputs,
Expand Down Expand Up @@ -544,7 +551,7 @@ def predict(
go_fast = False
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)

if go_fast:
if go_fast and not self.disable_fp8:
imgs, np_imgs = self.fp8_predict(
prompt,
num_outputs,
Expand All @@ -555,6 +562,8 @@ def predict(
**hws_kwargs,
)
else:
if self.disable_fp8:
print("running bf16 model, fp8 disabled")
imgs, np_imgs = self.base_predict(
prompt,
num_outputs,
Expand Down

0 comments on commit c720e2d

Please sign in to comment.