
Commit 8ed7e69

feat(library): add default python operations

This is mainly for debugging for now, but it allows disabling the library extensions and using implementations based on pytorch operators only instead.

1 parent 5c4951b

20 files changed: +149 −30 lines
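As context for the change, a minimal usage sketch mirroring the updated test at the end of this diff (`pack_weights` and `int4` are taken from `quanto.tensor.core`): calls to `torch.ops.quanto.unpack` normally route to the optimized extensions, and the new `disable_extensions()` context manager forces the pytorch-only implementations instead.

```
import torch

from quanto.library import disable_extensions
from quanto.tensor.core import int4, pack_weights

# Pack some random 4-bit values into uint8 storage.
a = torch.randint(0, 16, (32, 32), dtype=torch.uint8)
packed = pack_weights(a, int4)

# Default path: routed to the optimized quanto_ext:: kernels.
unpacked_ext = torch.ops.quanto.unpack(packed, 4)

# Debug path: force the pure-pytorch quanto_py:: implementation.
with disable_extensions():
    unpacked_py = torch.ops.quanto.unpack(packed, 4)

assert torch.equal(unpacked_ext, unpacked_py)
```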

quanto/library/README.md

+9 (new file)

# Quanto operations library

This contains the `quanto::` operations, available in python under `torch.ops.quanto`.

To add a new operation:

- add a definition for the operation in `library/ops.py`,
- provide a default implementation using pytorch operators only under `library/python`,
- provide optimized kernels for all devices under `library/ext`.
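A hedged sketch of the first two steps for a hypothetical `group_max` operation (the operation name, schema, and implementation below are invented for illustration and are not part of this commit); the optimized kernels would then be registered under `quanto_ext::group_max` in `library/ext`:

```
# library/ops.py — declare the operation (hypothetical example),
# using the define() helper added in this commit.
define("group_max", "(Tensor self, int group_size) -> Tensor")

# library/python/group_max.py — default implementation, pytorch operators only.
import torch


@torch.library.impl("quanto_py::group_max", "default")
def group_max(t: torch.Tensor, group_size: int) -> torch.Tensor:
    # Assumes t.numel() is a multiple of group_size.
    return t.reshape(-1, group_size).amax(dim=-1)
```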

quanto/library/__init__.py

+2 −7

@@ -1,8 +1,3 @@
-import torch
-
-from .cpp import *
+from .ext import *
 from .ops import *
-
-
-if torch.backends.mps.is_available():
-    from .mps import *
+from .python import *

quanto/library/builtin/__init__.py

Whitespace-only changes.

quanto/library/ext/README.md

+23 (new file)

# Quanto library extensions

This folder contains the implementations of all `quanto_ext::` operations.

This namespace corresponds to the device-specific optimized implementations of quanto operations.

Implementations can be provided as part of:

- the generic C++ pytorch extension under `cpp`,
- the CUDA extension under `cuda`,
- the Metal Performance Shader extension under `mps`.

The operations are defined in `library/ops.py`.

To provide an implementation for specific device types, use the following syntax:

```
@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
    return ext().unpack(packed, bits)
```

Please refer to each extension folder to see how to add the actual implementation.

quanto/library/ext/__init__.py

+7 (new file)

import torch

from .cpp import *


if torch.backends.mps.is_available():
    from .mps import *

quanto/library/ext/cpp/README.md

+11 (new file)

# Quanto generic C++ extension

Kernels in this extension must use only C++ syntax.

They can use any pytorch operation defined under `aten::` or `c10::`.

To add a new implementation for an operation defined in `library/ops.py`:

- add the corresponding `.cpp` file to the list of sources in `__init__.py`,
- add a binding to `pybind_module.cpp`,
- provide an implementation calling the binding in `__init__.py`.
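For the python-facing side of those steps, a hedged sketch of what `__init__.py` would gain for a hypothetical `pack` kernel (the `pack` binding and `pack.cpp` source are invented for illustration; `ext()` and `@impl` are the helpers already used in this file, shown in the diff below):

```
# __init__.py — assumes pack.cpp was appended to the `sources` list passed
# to load(...) and a `pack` binding was added to pybind_module.cpp
# (both hypothetical).
@impl("quanto_ext::pack", ["CPU", "CUDA"])
def pack_cpp(t: torch.Tensor, bits: int):
    return ext().pack(t, bits)
```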

quanto/library/cpp/__init__.py → quanto/library/ext/cpp/__init__.py

+9 −9

@@ -8,25 +8,25 @@
 __all__ = []


-_backend = None
+_ext = None


-def backend():
-    """Helper to load the CPU backend only when it is required"""
-    global _backend
-    if _backend is None:
+def ext():
+    """Helper to load the CPU extension only when it is required"""
+    global _ext
+    if _ext is None:
         module_path = os.path.dirname(__file__)
-        _backend = load(
+        _ext = load(
             name="quanto_cpp",
             sources=[
                 f"{module_path}/unpack.cpp",
                 f"{module_path}/pybind_module.cpp",
             ],
             extra_cflags=["-O3"],
         )
-    return _backend
+    return _ext


-@impl("quanto::unpack", ["CPU", "CUDA"])
+@impl("quanto_ext::unpack", ["CPU", "CUDA"])
 def unpack_cpp(t: torch.Tensor, bits: int):
-    return backend().unpack(t, bits)
+    return ext().unpack(t, bits)
File renamed without changes.
File renamed without changes.
File renamed without changes.

quanto/library/ext/mps/README.md

+7 (new file)

# Quanto Metal Performance Shaders extension

To add a new implementation for an operation defined in `library/ops.py`:

- add the corresponding `.mm` file to the list of sources in `__init__.py`,
- add a binding to `pybind_module.cpp`,
- provide an implementation calling the binding in `__init__.py`.

quanto/library/mps/__init__.py → quanto/library/ext/mps/__init__.py

+9 −9

@@ -8,22 +8,22 @@
 __all__ = []


-_backend = None
+_ext = None


-def backend():
-    """Helper to load the MPS backend only when it is required"""
-    global _backend
-    if _backend is None:
+def ext():
+    """Helper to load the MPS extension only when it is required"""
+    global _ext
+    if _ext is None:
         module_path = os.path.dirname(__file__)
-        _backend = load(
+        _ext = load(
             name="quanto_mps",
             sources=[f"{module_path}/unpack.mm", f"{module_path}/pybind_module.cpp"],
             extra_cflags=["-std=c++17"],
         )
-    return _backend
+    return _ext


-@impl("quanto::unpack", "MPS")
+@impl("quanto_ext::unpack", "MPS")
 def unpack_mps(t: torch.Tensor, bits: int):
-    return backend().unpack(t, bits)
+    return ext().unpack(t, bits)
File renamed without changes.
File renamed without changes.
File renamed without changes.

quanto/library/ops.py

+44 −2

@@ -1,6 +1,48 @@
-from torch.library import define
+from contextlib import contextmanager
+
+import torch


 # This file contains the definitions of all operations under torch.ops.quanto

-define("quanto::unpack", "(Tensor self, int bits) -> Tensor")
+
+_ext_enabled = True
+
+
+@contextmanager
+def disable_extensions():
+    """Disable quanto extensions (debug)"""
+    try:
+        global _ext_enabled
+        _ext_enabled = False
+        yield
+    finally:
+        _ext_enabled = True
+
+
+def define(name, schema):
+    """Define a new quanto operation.
+
+    The operation will actually be defined in three libraries:
+    - the top-level quanto library as quanto::<op>,
+    - the quanto python library as quanto_py::<op>,
+    - the quanto extension library as quanto_ext::<op>.
+
+    Only the implementations for the python and extension libraries need
+    to be provided: the top-level implementation for the operation is
+    provided when calling this method and simply routes the calls towards
+    either the python or extension implementations based on the selected
+    mode.
+    """
+    for libname in ["quanto", "quanto_py", "quanto_ext"]:
+        torch.library.define(f"{libname}::{name}", schema)
+
+    # Provide the implementation for all dispatch keys in the main library
+    @torch.library.impl(f"quanto::{name}", "default")
+    def impl(*args, **kwargs):
+        if _ext_enabled:
+            return getattr(torch.ops.quanto_ext, name)(*args, **kwargs)
+        return getattr(torch.ops.quanto_py, name)(*args, **kwargs)
+
+
+define("unpack", "(Tensor self, int bits) -> Tensor")

quanto/library/python/README.md

+18 (new file)

# Quanto library python/pytorch operations

This folder contains the implementations of all `quanto_py::` operations.

This namespace corresponds to the default, python-only implementations of quanto operations.

The operations are defined in `library/ops.py`.

To provide an implementation for an operation, use the following syntax:

```
@torch.library.impl("quanto_py::unpack", "default")
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
    ...
```

The implementation **must** support all device types. This is true if it is a composition of built-in PyTorch operators.
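For reference, a minimal pure-pytorch sketch of such an implementation, assuming a packing layout where each uint8 stores `8 // bits` values, least-significant bits first, stacked along the first dimension; the actual layout produced by quanto's `pack_weights` may differ:

```
import torch


@torch.library.impl("quanto_py::unpack", "default")
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
    # Extract each bit-field from the uint8 storage and stack the groups
    # back along dim 0 (layout is an assumption for this sketch).
    n_fields = 8 // bits
    mask = (1 << bits) - 1
    fields = [(packed >> (bits * i)) & mask for i in range(n_fields)]
    return torch.cat(fields, dim=0).to(torch.int8)
```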

quanto/library/python/__init__.py

+1 (new file)

from .unpack import *

quanto/library/builtin/unpack.py → quanto/library/python/unpack.py

+1 −1

@@ -1,7 +1,7 @@
 import torch


-@torch.libary.impl("quanto::unpack", "default")
+@torch.library.impl("quanto_py::unpack", "default")
 def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
     """
     Un-Pack int4 / int2 weights (packed in a uint8) into a torch.int8 tensor

test/library/test_unpack.py

+8 −2

@@ -1,15 +1,21 @@
+from contextlib import nullcontext
+
 import pytest
 import torch

+from quanto.library import disable_extensions
 from quanto.tensor.core import int2, int4, pack_weights


 @pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"])
 @pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
-def test_unpack(bits, shape, device):
+@pytest.mark.parametrize("use_ext", [True, False], ids=["ext", "no-ext"])
+def test_unpack(bits, shape, use_ext, device):
     qmax = 2**bits
     a = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
     bitsdtype = int2 if bits == 2 else int4
     packed_a = pack_weights(a, bitsdtype)
-    unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
+    context = nullcontext() if use_ext else disable_extensions()
+    with context:
+        unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
     assert torch.equal(unpacked_a, a)
