Add fallback to library and refactor bench #82

Merged: 3 commits, Feb 8, 2024
32 changes: 15 additions & 17 deletions bench/library/benchmark.py
@@ -1,28 +1,25 @@
 import argparse
 import time
+from contextlib import nullcontext
 
 import numpy as np
 import torch
 from tqdm.auto import tqdm
 
-from quanto.tensor.core import int2, int4, unpack_weights
+from quanto.library import disable_extensions
 
 
 def get_unpack_bench(bits, device):
     qmax = 2**bits
     a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)
-    bitsdtype = int2 if bits == 2 else int4
-
-    def torch_fn():
-        return unpack_weights(a, bitsdtype)
 
-    def kernel_fn():
+    def bench_fn():
         return torch.ops.quanto.unpack(a, bits)
 
-    return [torch_fn, kernel_fn]
+    return bench_fn
 
 
-def timing(get_bench_functions, device, iterations=10):
+def timing(get_bench_func, device, iterations=10):
     def synchronize(device):
         if device.type == "cuda":
             torch.cuda.synchronize()
@@ -53,15 +50,18 @@ def elapsed_time(self, other):

     synchronize(device)
 
+    bench_func = get_bench_func(device)
+    # Warmup to load library
+    bench_func()
     latencies = np.empty((iterations, 2))
     for i in tqdm(range(iterations)):
-        bench_functions = get_bench_functions(device)
-        for j, fn in enumerate(bench_functions):
+        for j, context in enumerate([disable_extensions(), nullcontext()]):
             start_event = timing_event(device)
             end_event = timing_event(device)
             synchronize(device)
             start_event.record()
-            fn()
+            with context:
+                bench_func()
             end_event.record()
             synchronize(device)
             latencies[i, j] = start_event.elapsed_time(end_event)
@@ -92,12 +92,10 @@ def main():
     all_kernels = ["unpack_2bit", "unpack_4bit"]
     kernels = all_kernels if args.kernel is None else [args.kernel]
     for kernel in kernels:
-        get_bench_functions = GET_BENCH_FUNCTIONS[kernel]
-        torch_ms, kernel_ms = timing(get_bench_functions, device, iterations=args.it)
-        ratio = torch_ms / kernel_ms
-        print(
-            f"\n{kernel}[{device.type}]: torch = {torch_ms:.3f} ms, kernel = {kernel_ms:.3f} ms, ratio = {ratio:.1f}x"
-        )
+        get_bench_fn = GET_BENCH_FUNCTIONS[kernel]
+        python_ms, ext_ms = timing(get_bench_fn, device, iterations=args.it)
+        ratio = python_ms / ext_ms
+        print(f"\n{kernel}[{device.type}]: python = {python_ms:.3f} ms, ext = {ext_ms:.3f} ms, ratio = {ratio:.1f}x")
 
 
 if __name__ == "__main__":
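For context on the refactor above: timing now receives a single bench-function factory and measures the same call twice per iteration, once inside disable_extensions() (forcing the pure Python implementation) and once inside nullcontext() (letting the extension kernel run), filling the two columns of latencies that main() reports as python vs ext. The sketch below reproduces that pattern in a standalone form; it is not the benchmark's exact code. It uses time.perf_counter instead of the device timing events (so it is only meaningful on CPU), a smaller tensor, and a mean over iterations, all of which are illustrative assumptions.

import time
from contextlib import nullcontext

import numpy as np
import torch

from quanto.library import disable_extensions


def bench_unpack(device, bits=4, iterations=10):
    # Same workload as get_unpack_bench, with a smaller tensor for illustration.
    a = torch.randint(0, 2**bits, [1024, 1024], dtype=torch.uint8).to(device)

    latencies = np.empty((iterations, 2))
    for i in range(iterations):
        # Column 0: extensions disabled (python fallback), column 1: extension allowed.
        for j, context in enumerate([disable_extensions(), nullcontext()]):
            start = time.perf_counter()
            with context:
                torch.ops.quanto.unpack(a, bits)
            latencies[i, j] = (time.perf_counter() - start) * 1000  # milliseconds

    # Reducing to a (python_ms, ext_ms) pair by averaging is an assumption here;
    # the actual timing() helper may aggregate differently.
    python_ms, ext_ms = latencies.mean(axis=0)
    return python_ms, ext_ms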
11 changes: 9 additions & 2 deletions quanto/library/ops.py
@@ -1,3 +1,4 @@
+import warnings
 from contextlib import contextmanager
 
 import torch
@@ -38,10 +39,16 @@ def define(name, schema):
         torch.library.define(f"{libname}::{name}", schema)
 
     # Provide the inplementation for all dispatch key in the main library
-    @torch.library.impl("quanto::unpack", "default")
+    @torch.library.impl(f"quanto::{name}", "default")
     def impl(*args, **kwargs):
         if _ext_enabled:
-            return getattr(torch.ops.quanto_ext, name)(*args, **kwargs)
+            try:
+                return getattr(torch.ops.quanto_ext, name)(*args, **kwargs)
+            except Exception as e:
+                warnings.warn(
+                    f"A {type(e)} exception occured while calling the optimized kernel for quanto::{name}."
+                    "Falling back to default implementation."
+                )
         return getattr(torch.ops.quanto_py, name)(*args, **kwargs)


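Taken together, each op registered through define() keeps a single public entry point, torch.ops.quanto.<name>: the quanto_ext kernel is tried first, a warning is emitted if it raises, and the quanto_py implementation is used as the fallback, while disable_extensions() skips the extension entirely. The caller-side sanity check below is a sketch under that reading of the diff; the unpack op name and its (tensor, bits) signature are taken from the benchmark above, and the tensor shape and bit width are illustrative.

import warnings

import torch

from quanto.library import disable_extensions

t = torch.randint(0, 16, [256, 256], dtype=torch.uint8)

# Default path: the quanto_ext kernel is tried first; if it raises, a warning
# is emitted and the quanto_py implementation is called instead.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    unpacked = torch.ops.quanto.unpack(t, 4)
    if caught:
        print("extension kernel failed, python fallback was used")

# Forced python path: disable_extensions() clears _ext_enabled for the duration
# of the block, so the quanto_py implementation is called directly.
with disable_extensions():
    reference = torch.ops.quanto.unpack(t, 4)

# Both paths should produce the same result.
assert torch.equal(unpacked, reference)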