From 0aac3088e302202de32463c8b280a2d9b4d13a55 Mon Sep 17 00:00:00 2001
From: Yorick van Pelt
Date: Thu, 10 Oct 2024 17:30:21 +0200
Subject: [PATCH] Ruff format

---
 predict.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/predict.py b/predict.py
index 60e29ef..4dcac0a 100644
--- a/predict.py
+++ b/predict.py
@@ -168,7 +168,7 @@ def base_setup(
         total_mem = torch.cuda.get_device_properties(0).total_memory
         self.offload = total_mem < 48 * 1024**3
         if self.offload:
-            print("GPU memory is:", total_mem / 1024 ** 3, ", offloading models")
+            print("GPU memory is:", total_mem / 1024**3, ", offloading models")

         device = "cuda"
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512
@@ -197,25 +197,25 @@ def base_setup(
             extra_args = {
                 "compile_whole_model": True,
                 "compile_extras": True,
-                "compile_blocks": True
+                "compile_blocks": True,
             }
         else:
             extra_args = {
                 "compile_whole_model": False,
                 "compile_extras": False,
-                "compile_blocks": False
+                "compile_blocks": False,
             }
         if self.offload:
             extra_args |= {
                 "offload_text_encoder": True,
                 "offload_vae": True,
-                "offload_flow": True
+                "offload_flow": True,
             }

         self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
             f"fp8/configs/config-1-{flow_model_name}-h100.json",
             shared_models=shared_models,
-            **extra_args
+            **extra_args,
         )

         if compile_fp8: