Skip to content

Commit

Permalink
Enable ROCM in CI (#999)
Browse files Browse the repository at this point in the history
* Enable ROCM in CI

---------

Co-authored-by: amdfaa <[email protected]>
  • Loading branch information
msaroufim and amdfaa authored Jan 17, 2025
1 parent cf45336 commit d96c6a7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 10 additions & 3 deletions .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ concurrency:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}

permissions:
id-token: write
contents: read

jobs:
test-nightly:
strategy:
Expand All @@ -33,10 +37,16 @@ jobs:
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""
- name: ROCM Nightly
runs-on: linux.rocm.gpu.2
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3'
gpu-arch-type: "rocm"
gpu-arch-version: "6.3"

uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
timeout: 120
no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }}
runner: ${{ matrix.runs-on }}
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
Expand Down Expand Up @@ -71,7 +81,6 @@ jobs:
torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"

- name: CPU 2.3
runs-on: linux.4xlarge
torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
Expand Down Expand Up @@ -99,8 +108,6 @@ jobs:
conda create -n venv python=3.9 -y
conda activate venv
echo "::group::Install newer objcopy that supports --set-section-alignment"
yum install -y devtoolset-10-binutils
export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
python -m pip install --upgrade pip
pip install ${{ matrix.torch-spec }}
pip install -r dev-requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def _torch_version_at_least(min_version):
def is_MI300():
if torch.cuda.is_available() and torch.version.hip:
mxArchName = ["gfx940", "gfx941", "gfx942"]
archName = torch.cuda.get_device_properties().gcnArchName
archName = torch.cuda.get_device_properties(0).gcnArchName
for arch in mxArchName:
if arch in archName:
return True
Expand Down

0 comments on commit d96c6a7

Please sign in to comment.