Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
metascroy committed Feb 10, 2025
1 parent 6d7ead0 commit 159e270
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions torchao/experimental/ops/mps/test/test_lowbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,6 @@
import torch
from parameterized import parameterized

try:
print("TRYING")
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
try:
print("LOADING LIB")
libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname))
print("AT ", libpath)
torch.ops.load_library(libpath)
print("LOADED")
except Exception as e:
print("FAILED TO LOAD")
raise e
# raise RuntimeError(f"Failed to load library {libpath}")
else:
try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError as e:
raise e


class TestLowBitQuantWeightsLinear(unittest.TestCase):
CASES = [
(nbit, *param)
Expand Down Expand Up @@ -105,4 +79,30 @@ def test_linear(self, nbit, M=1, K=32, N=32, group_size=32):

if __name__ == "__main__":
print("RUNNING UNIT TESTS")
try:
print("TRYING")
for nbit in range(1, 8):
print("NBIT", nbit)
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
try:
print("LOADING LIB")
libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname))
print("AT ", libpath)
torch.ops.load_library(libpath)
print("LOADED")
except Exception as e:
print("FAILED TO LOAD")
raise e
# raise RuntimeError(f"Failed to load library {libpath}")
else:
try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError as e:
raise e

unittest.main()

0 comments on commit 159e270

Please sign in to comment.