diff --git a/bench/library/benchmark.py b/bench/library/benchmark.py index 2bb1aa9e..a3e5159c 100644 --- a/bench/library/benchmark.py +++ b/bench/library/benchmark.py @@ -1,28 +1,25 @@ import argparse import time +from contextlib import nullcontext import numpy as np import torch from tqdm.auto import tqdm -from quanto.tensor.core import int2, int4, unpack_weights +from quanto.library import disable_extensions def get_unpack_bench(bits, device): qmax = 2**bits a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device) - bitsdtype = int2 if bits == 2 else int4 - def torch_fn(): - return unpack_weights(a, bitsdtype) - - def kernel_fn(): + def bench_fn(): return torch.ops.quanto.unpack(a, bits) - return [torch_fn, kernel_fn] + return bench_fn -def timing(get_bench_functions, device, iterations=10): +def timing(get_bench_func, device, iterations=10): def synchronize(device): if device.type == "cuda": torch.cuda.synchronize() @@ -53,15 +50,18 @@ def elapsed_time(self, other): synchronize(device) + bench_func = get_bench_func(device) + # Warmup to load library + bench_func() latencies = np.empty((iterations, 2)) for i in tqdm(range(iterations)): - bench_functions = get_bench_functions(device) - for j, fn in enumerate(bench_functions): + for j, context in enumerate([disable_extensions(), nullcontext()]): start_event = timing_event(device) end_event = timing_event(device) synchronize(device) start_event.record() - fn() + with context: + bench_func() end_event.record() synchronize(device) latencies[i, j] = start_event.elapsed_time(end_event) @@ -92,12 +92,10 @@ def main(): all_kernels = ["unpack_2bit", "unpack_4bit"] kernels = all_kernels if args.kernel is None else [args.kernel] for kernel in kernels: - get_bench_functions = GET_BENCH_FUNCTIONS[kernel] - torch_ms, kernel_ms = timing(get_bench_functions, device, iterations=args.it) - ratio = torch_ms / kernel_ms - print( - f"\n{kernel}[{device.type}]: torch = {torch_ms:.3f} ms, kernel = {kernel_ms:.3f} ms, ratio = {ratio:.1f}x" - ) + get_bench_fn = GET_BENCH_FUNCTIONS[kernel] + python_ms, ext_ms = timing(get_bench_fn, device, iterations=args.it) + ratio = python_ms / ext_ms + print(f"\n{kernel}[{device.type}]: python = {python_ms:.3f} ms, ext = {ext_ms:.3f} ms, ratio = {ratio:.1f}x") if __name__ == "__main__": diff --git a/quanto/library/ops.py b/quanto/library/ops.py index 02cf9112..064472e1 100644 --- a/quanto/library/ops.py +++ b/quanto/library/ops.py @@ -1,3 +1,4 @@ +import warnings from contextlib import contextmanager import torch @@ -38,10 +39,16 @@ def define(name, schema): torch.library.define(f"{libname}::{name}", schema) # Provide the inplementation for all dispatch key in the main library - @torch.library.impl("quanto::unpack", "default") + @torch.library.impl(f"quanto::{name}", "default") def impl(*args, **kwargs): if _ext_enabled: - return getattr(torch.ops.quanto_ext, name)(*args, **kwargs) + try: + return getattr(torch.ops.quanto_ext, name)(*args, **kwargs) + except Exception as e: + warnings.warn( + f"A {type(e)} exception occured while calling the optimized kernel for quanto::{name}." + "Falling back to default implementation." + ) return getattr(torch.ops.quanto_py, name)(*args, **kwargs)