Skip to content

Commit

Permalink
Merge pull request #354 from torchmd/osx_arm64_build_fix
Browse files Browse the repository at this point in the history
Linux Aarch64 build fix
  • Loading branch information
stefdoerr authored Jan 31, 2025
2 parents 8374e96 + 35dbb90 commit 1deecd1
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 18 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
pip -vv install .
fi
env:
CPU_ONLY: 1
WITH_CUDA: "0"

- name: Lint with flake8
run: |
Expand All @@ -89,7 +89,7 @@ jobs:
- name: Run tests
run: pytest -v -s --durations=10
env:
CPU_ONLY: 1
WITH_CUDA: "0"
SKIP_TORCH_COMPILE: ${{ runner.os == 'Windows' && 'true' || 'false' }}
OMP_PREFIX: ${{ runner.os == 'macOS' && '/Users/runner/miniconda3/envs/test' || '' }}
CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}
Expand Down
57 changes: 41 additions & 16 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,36 @@
import subprocess
from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, include_paths, CppExtension
from torch.utils.cpp_extension import (
BuildExtension,
CUDAExtension,
include_paths,
CppExtension,
)
import os
import sys

is_windows = sys.platform == 'win32'
is_windows = sys.platform == "win32"

try:
version = (
subprocess.check_output(["git", "describe", "--abbrev=0", "--tags"])
.strip()
.decode("utf-8")
)
except:
except Exception:
print("Failed to retrieve the current version, defaulting to 0")
version = "0"
# If CPU_ONLY is defined
force_cpu_only = os.environ.get("CPU_ONLY", None) is not None
use_cuda = torch.cuda._is_compiled() if not force_cpu_only else False

# If WITH_CUDA is defined
if os.environ.get("WITH_CUDA", "0") == "1":
use_cuda = True
else:
use_cuda = torch.cuda._is_compiled()


def set_torch_cuda_arch_list():
""" Set the CUDA arch list according to the architectures the current torch installation was compiled for.
"""Set the CUDA arch list according to the architectures the current torch installation was compiled for.
This function is a no-op if the environment variable TORCH_CUDA_ARCH_LIST is already set or if torch was not compiled with CUDA support.
"""
if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
Expand All @@ -35,20 +45,24 @@ def set_torch_cuda_arch_list():
formatted_versions += "+PTX"
os.environ["TORCH_CUDA_ARCH_LIST"] = formatted_versions


set_torch_cuda_arch_list()

extension_root= os.path.join("torchmdnet", "extensions")
neighbor_sources=["neighbors_cpu.cpp"]
extension_root = os.path.join("torchmdnet", "extensions")
neighbor_sources = ["neighbors_cpu.cpp"]
if use_cuda:
neighbor_sources.append("neighbors_cuda.cu")
neighbor_sources = [os.path.join(extension_root, "neighbors", source) for source in neighbor_sources]
neighbor_sources = [
os.path.join(extension_root, "neighbors", source) for source in neighbor_sources
]

ExtensionType = CppExtension if not use_cuda else CUDAExtension
extensions = ExtensionType(
name='torchmdnet.extensions.torchmdnet_extensions',
sources=[os.path.join(extension_root, "torchmdnet_extensions.cpp")] + neighbor_sources,
name="torchmdnet.extensions.torchmdnet_extensions",
sources=[os.path.join(extension_root, "torchmdnet_extensions.cpp")]
+ neighbor_sources,
include_dirs=include_paths(),
define_macros=[('WITH_CUDA', 1)] if use_cuda else [],
define_macros=[("WITH_CUDA", 1)] if use_cuda else [],
)

if __name__ == "__main__":
Expand All @@ -58,8 +72,19 @@ def set_torch_cuda_arch_list():
packages=find_packages(),
ext_modules=[extensions],
cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)},
"build_ext": BuildExtension.with_options(
no_python_abi_suffix=True, use_ninja=False
)
},
include_package_data=True,
entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]},
package_data={"torchmdnet": ["extensions/torchmdnet_extensions.so"] if not is_windows else ["extensions/torchmdnet_extensions.dll"]},
entry_points={
"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]
},
package_data={
"torchmdnet": (
["extensions/torchmdnet_extensions.so"]
if not is_windows
else ["extensions/torchmdnet_extensions.dll"]
)
},
)

0 comments on commit 1deecd1

Please sign in to comment.