diff --git a/predict.py b/predict.py
index 840855d..45c1caa 100644
--- a/predict.py
+++ b/predict.py
@@ -148,7 +148,7 @@ def base_setup(
         total_mem = torch.cuda.get_device_properties(0).total_memory
         self.offload = total_mem < 48 * 1024**3
         if self.offload:
-            print("GPU memory is:", total_mem / 1024 ** 3, ", offloading models")
+            print("GPU memory is:", total_mem / 1024**3, ", offloading models")
 
         device = "cuda"
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512
@@ -177,25 +177,25 @@ def base_setup(
             extra_args = {
                 "compile_whole_model": True,
                 "compile_extras": True,
-                "compile_blocks": True
+                "compile_blocks": True,
             }
         else:
             extra_args = {
                 "compile_whole_model": False,
                 "compile_extras": False,
-                "compile_blocks": False
+                "compile_blocks": False,
             }
         if self.offload:
             extra_args |= {
                 "offload_text_encoder": True,
                 "offload_vae": True,
-                "offload_flow": True
+                "offload_flow": True,
             }
 
         self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
             f"fp8/configs/config-1-{flow_model_name}-h100.json",
             shared_models=shared_models,
-            **extra_args
+            **extra_args,
         )
 
         if compile_fp8: