|
1 |
| -from torch.library import define |
| 1 | +from contextlib import contextmanager |
| 2 | + |
| 3 | +import torch |
2 | 4 |
|
3 | 5 |
|
4 | 6 | # This file contains the definitions of all operations under torch.ops.quanto
|
5 | 7 |
|
6 |
| -define("quanto::unpack", "(Tensor self, int bits) -> Tensor") |
| 8 | + |
| 9 | +_ext_enabled = True |
| 10 | + |
| 11 | + |
| 12 | +@contextmanager |
| 13 | +def disable_extensions(): |
| 14 | + """Disable quanto extensions (debug)""" |
| 15 | + try: |
| 16 | + global _ext_enabled |
| 17 | + _ext_enabled = False |
| 18 | + yield |
| 19 | + finally: |
| 20 | + _ext_enabled = True |
| 21 | + |
| 22 | + |
| 23 | +def define(name, schema): |
| 24 | + """Define a new quanto operation. |
| 25 | +
|
| 26 | + The operation will actually be defined in three libraries: |
| 27 | + - the top-level quanto library as quanto::<op>, |
| 28 | + - the quanto python library as quanto_py::<op>, |
| 29 | + - the quanto extension library as quanto_ext::<op>. |
| 30 | +
|
| 31 | + Only the implementations for the python and extension library need |
| 32 | + to be provided: the top-level implementation for the operation is |
| 33 | + provided when calling this method and simply routes the calls towards |
| 34 | + either the python or extension implementations based on the selected |
| 35 | + mode. |
| 36 | + """ |
| 37 | + for libname in ["quanto", "quanto_py", "quanto_ext"]: |
| 38 | + torch.library.define(f"{libname}::{name}", schema) |
| 39 | + |
| 40 | + # Provide the inplementation for all dispatch key in the main library |
| 41 | + @torch.library.impl("quanto::unpack", "default") |
| 42 | + def impl(*args, **kwargs): |
| 43 | + if _ext_enabled: |
| 44 | + return getattr(torch.ops.quanto_ext, name)(*args, **kwargs) |
| 45 | + return getattr(torch.ops.quanto_py, name)(*args, **kwargs) |
| 46 | + |
| 47 | + |
| 48 | +define("unpack", "(Tensor self, int bits) -> Tensor") |
0 commit comments