Skip Unit Tests for ROCm CI (#1563)
* skip failing unit tests for ROCm CI

* fix util import
petrex authored Jan 17, 2025
1 parent d96c6a7 commit a1c67b9
Showing 16 changed files with 71 additions and 1 deletion.
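
All of the skips rely on the same runtime signal: a ROCm build of PyTorch reports a version string in torch.version.hip, while a CUDA build reports None there. A quick illustrative check (not part of the diff):

import torch

# ROCm wheels set torch.version.hip to a version string (e.g. "6.x...") and
# leave torch.version.cuda as None; CUDA wheels do the reverse.
print("ROCm build" if torch.version.hip is not None else "not a ROCm build")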
Empty file added test/__init__.py
4 changes: 4 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -2,6 +2,7 @@
 import unittest
 
 import torch
+from test_utils import skip_if_rocm
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -89,6 +90,7 @@ def test_tensor_core_layout_transpose(self):
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
 
+    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
         "apply_quant", get_quantization_functions(True, True, "cuda", True)
@@ -168,6 +170,7 @@ def apply_uint6_weight_only_quant(linear):
 
         deregister_aqt_quantized_linear_dispatch(dispatch_condition)
 
+    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
@@ -180,6 +183,7 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
+    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_flatten_unflatten(self, device, dtype):
2 changes: 2 additions & 0 deletions test/dtypes/test_floatx.py
@@ -2,6 +2,7 @@
 import unittest
 
 import torch
+from test_utils import skip_if_rocm
 from torch.testing._internal.common_utils import (
     TestCase,
     instantiate_parametrized_tests,
@@ -108,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits):
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])
+    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(is_fbcode(), reason="broken in fbcode")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
3 changes: 3 additions & 0 deletions test/float8/test_base.py
@@ -24,6 +24,8 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+from test_utils import skip_if_rocm
+
 from torchao.float8.config import (
     CastConfig,
     Float8LinearConfig,
@@ -423,6 +425,7 @@ def test_linear_from_config_params(
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @skip_if_rocm("ROCm development in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,
2 changes: 2 additions & 0 deletions test/hqq/test_hqq_affine.py
@@ -1,6 +1,7 @@
 import unittest
 
 import torch
+from test_utils import skip_if_rocm
 
 from torchao.quantization import (
     MappingType,
@@ -110,6 +111,7 @@ def test_hqq_plain_5bit(self):
             ref_dot_product_error=0.000704,
         )
 
+    @skip_if_rocm("ROCm development in progress")
     def test_hqq_plain_4bit(self):
         self._test_hqq(
             dtype=torch.uint4,
7 changes: 7 additions & 0 deletions test/integration/test_integration.py
@@ -93,6 +93,8 @@
 except ModuleNotFoundError:
     has_gemlite = False
 
+from test_utils import skip_if_rocm
+
 logger = logging.getLogger("INFO")
 
 torch.manual_seed(0)
@@ -569,6 +571,7 @@ def test_per_token_linear_cpu(self):
             self._test_per_token_linear_impl("cpu", dtype)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_per_token_linear_cuda(self):
         for dtype in (torch.float32, torch.float16, torch.bfloat16):
             self._test_per_token_linear_impl("cuda", dtype)
@@ -687,6 +690,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -706,6 +710,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -899,6 +904,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -918,6 +924,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
    def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
        if dtype != torch.bfloat16:
            self.skipTest(f"Fails for {dtype}")
2 changes: 2 additions & 0 deletions test/kernel/test_galore_downproj.py
@@ -8,6 +8,7 @@
 
 import torch
 from galore_test_utils import make_data
+from test_utils import skip_if_rocm
 
 from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
 from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
@@ -29,6 +30,7 @@
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
 @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
+@skip_if_rocm("ROCm development in progress")
 def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
     MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32
3 changes: 3 additions & 0 deletions test/prototype/test_awq.py
@@ -10,6 +10,8 @@
 if TORCH_VERSION_AT_LEAST_2_3:
     from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
 
+from test_utils import skip_if_rocm
+
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
@@ -113,6 +115,7 @@ def test_awq_loading(device, qdtype):
 
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_rocm("ROCm development in progress")
 def test_save_weights_only():
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128
2 changes: 2 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -42,6 +42,7 @@
 except ImportError:
     lpmm = None
 
+from test_utils import skip_if_rocm
 
 _DEVICES = get_available_devices()
 
@@ -112,6 +113,7 @@ class TestOptim(TestCase):
     )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
+    @skip_if_rocm("ROCm development in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
             if not TORCH_VERSION_AT_LEAST_2_4:
3 changes: 3 additions & 0 deletions test/prototype/test_splitk.py
@@ -13,13 +13,16 @@
 except ImportError:
     triton_available = False
 
+from test_utils import skip_if_rocm
+
 from torchao.utils import skip_if_compute_capability_less_than
 
 
 @unittest.skipIf(not triton_available, "Triton is required but not available")
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
 class TestFP8Gemm(TestCase):
     @skip_if_compute_capability_less_than(9.0)
+    @skip_if_rocm("ROCm development in progress")
     def test_gemm_split_k(self):
         dtype = torch.float16
         qdtype = torch.float8_e4m3fn
2 changes: 2 additions & 0 deletions test/quantization/test_galore_quant.py
@@ -13,6 +13,7 @@
     dequantize_blockwise,
     quantize_blockwise,
 )
+from test_utils import skip_if_rocm
 
 from torchao.prototype.galore.kernels import (
     triton_dequant_blockwise,
@@ -82,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     "dim1,dim2,dtype,signed,blocksize",
     TEST_CONFIGS,
 )
+@skip_if_rocm("ROCm development in progress")
 def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01
 
3 changes: 3 additions & 0 deletions test/quantization/test_marlin_qqq.py
@@ -3,6 +3,7 @@
 
 import pytest
 import torch
+from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 
@@ -45,6 +46,7 @@ def setUp(self):
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq(self):
         output_ref = self.model(self.input)
         for group_size in [-1, 128]:
@@ -66,6 +68,7 @@ def test_marlin_qqq(self):
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq_compile(self):
         model_copy = copy.deepcopy(self.model)
         model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
4 changes: 3 additions & 1 deletion test/sparsity/test_marlin.py
@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 
@@ -37,6 +38,7 @@ def setUp(self):
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_eager(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
@@ -48,13 +50,13 @@ def test_quant_sparse_marlin_layout_eager(self):
         # Sparse + quantized
         quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)
-
         assert torch.allclose(
             dense_result, sparse_result, atol=3e-1
         ), "Results are not close"
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_compile(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
3 changes: 3 additions & 0 deletions test/test_ops.py
@@ -19,6 +19,9 @@
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode
 
+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+
 if is_fbcode():
     pytest.skip(
         "Skipping the test in fbcode since we don't have TARGET file for kernels"
3 changes: 3 additions & 0 deletions test/test_s8s4_linear_cutlass.py
@@ -7,6 +7,9 @@
 from torchao.quantization.utils import group_quantize_tensor_symmetric
 from torchao.utils import compute_max_diff
 
+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+
 S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
 S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
 S8S4_LINEAR_CUTLASS_SIZE_MNK = [
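
Note that the two files above (test/test_ops.py and test/test_s8s4_linear_cutlass.py) are skipped wholesale at collection time, while the other files skip individual tests with the new decorator. A minimal sketch contrasting the two styles as used in this commit; the test body is invented for illustration:

import pytest
import torch

from test_utils import skip_if_rocm

# Module-level style (test_ops.py, test_s8s4_linear_cutlass.py): nothing in
# the file is even collected on a ROCm build.
if torch.version.hip is not None:
    pytest.skip("Skipping the test in ROCm", allow_module_level=True)


# Per-test style (the remaining files): only the decorated test is skipped;
# the rest of the module still runs on ROCm.
@skip_if_rocm("ROCm development in progress")
def test_cuda_only_kernel():
    assert torch.cuda.is_available()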
29 changes: 29 additions & 0 deletions test/test_utils.py
@@ -1,11 +1,40 @@
+import functools
 import unittest
 from unittest.mock import patch
 
+import pytest
 import torch
 
 from torchao.utils import TorchAOBaseTensor, torch_version_at_least
 
 
+def skip_if_rocm(message=None):
+    """Decorator to skip tests on the ROCm platform, with an optional message.
+    Args:
+        message (str, optional): Additional information about why the test is skipped.
+    """
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.version.hip is not None:
+                skip_message = "Skipping the test in ROCm"
+                if message:
+                    skip_message += f": {message}"
+                pytest.skip(skip_message)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    # Handle both @skip_if_rocm and @skip_if_rocm() syntax
+    if callable(message):
+        func = message
+        message = None
+        return decorator(func)
+    return decorator
+
+
 class TestTorchVersionAtLeast(unittest.TestCase):
     def test_torch_version_at_least(self):
         test_cases = [
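
A usage sketch for the new helper (test names invented, not part of the commit). The callable(message) branch is what lets the decorator work with or without parentheses: in the bare form, the decorated function itself arrives as message:

import torch
from test_utils import skip_if_rocm


@skip_if_rocm  # bare form: the function lands in `message`
def test_bare():
    assert torch.ones(1).item() == 1.0


@skip_if_rocm("ROCm development in progress")  # the form used throughout this commit
def test_with_reason():
    assert torch.zeros(1).item() == 0.0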
