diff --git a/setup.py b/setup.py index 6cd4eee4..9ccd9774 100644 --- a/setup.py +++ b/setup.py @@ -9,13 +9,13 @@ # If WITH_CUDA is defined -if os.getenv("WITH_CUDA") is None: - if os.environ.get("WITH_CUDA", "0") == "1": - use_cuda = True - elif os.environ.get("WITH_CUDA", "0") == "0": - use_cuda = False - else: - raise ValueError("Invalid flag with WITH_CUDA environment variable. Expected '0' or '1'") +env_with_cuda = os.getenv("WITH_CUDA") +if env_with_cuda is not None: + if env_with_cuda not in ("0", "1"): + raise ValueError( + "Invalid flag with WITH_CUDA environment variable. Expected '0' or '1'" + ) + use_cuda = env_with_cuda == "1" else: use_cuda = torch.cuda._is_compiled()