Re-organize SLL ops, pt 8 #3663

Closed
wants to merge 1 commit
12 changes: 0 additions & 12 deletions .github/scripts/fbgemm_gpu_test.bash
@@ -83,18 +83,6 @@ __configure_fbgemm_gpu_test_cpu () {
# These tests have non-CPU operators referenced in @given
./uvm/copy_test.py
./uvm/uvm_test.py
./sll/triton_sll_test.py
./sll/array_jagged_bmm_jagged_out_test.py
./sll/jagged_dense_elementwise_add_test.py
./sll/jagged_flash_attention_basic_test.py
./sll/jagged_jagged_bmm_jagged_out_test.py
./sll/jagged_dense_flash_attention_test.py
./sll/jagged_dense_bmm_test.py
./sll/jagged_dense_elementwise_mul_jagged_out_test.py
./sll/jagged_jagged_bmm_test.py
./sll/jagged_softmax_test.py
./sll/jagged2_to_padded_dense_test.py
./sll/multi_head_jagged_flash_attention_test.py
)
}

21 changes: 1 addition & 20 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -33,11 +33,6 @@
meta_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.triton_sll import ( # noqa F401
jagged_dense_elementwise_mul_jagged_out,
triton_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.utils import TorchLibraryFragment

lib = TorchLibraryFragment("fbgemm")
@@ -262,25 +257,11 @@
},
}

# pyre-ignore[5]
sll_gpu_registrations = {
"sll_jagged_self_substraction_jagged_out": {
"CUDA": triton_jagged_self_substraction_jagged_out,
},
"sll_jagged_dense_elementwise_mul_jagged_out": {
"CUDA": jagged_dense_elementwise_mul_jagged_out,
"AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
},
}

for op_name, dispatches in sll_cpu_registrations.items():
lib.register(op_name, dispatches)

if torch.cuda.is_available():
from fbgemm_gpu.sll.triton import op_registrations

for op_name, dispatches in op_registrations.items():
lib.register(op_name, dispatches)
from fbgemm_gpu.sll.triton import op_registrations as sll_gpu_registrations

for op_name, dispatches in sll_gpu_registrations.items():
lib.register(op_name, dispatches)
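
The dicts above map each op name to callables keyed by dispatch key (e.g. "CUDA", "AutogradCUDA"), and TorchLibraryFragment.register wires them into the "fbgemm" library; the GPU registrations are only pulled in when CUDA is available. A rough sketch of the same idea using the plain torch.library API (the "example_sll" namespace and the double_it op are hypothetical, used only for illustration; this is not FBGEMM's wrapper):

import torch
from torch.library import Library

# Hypothetical library and op, only to illustrate per-dispatch-key registration.
example_lib = Library("example_sll", "DEF")
example_lib.define("double_it(Tensor x) -> Tensor")

def double_it_cpu(x: torch.Tensor) -> torch.Tensor:
    return x * 2

example_lib.impl("double_it", double_it_cpu, "CPU")

if torch.cuda.is_available():
    # Reuse the same Python function for the CUDA key in this toy example.
    example_lib.impl("double_it", double_it_cpu, "CUDA")

y = torch.ops.example_sll.double_it(torch.ones(3))  # dispatches to the CPU impl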
16 changes: 16 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
@@ -37,6 +37,11 @@
JaggedDenseAdd, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import ( # noqa F401
jagged_dense_elementwise_mul_jagged_out,
JaggedDenseElementwiseMul, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import ( # noqa F401
jagged_dense_flash_attention,
JaggedDenseFlashAttention, # noqa F401
@@ -47,6 +52,10 @@
JaggedFlashAttentionBasic, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import (
triton_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.triton.triton_jagged_softmax import ( # noqa F401
jagged2_softmax,
Jagged2Softmax, # noqa F401
@@ -108,4 +117,11 @@
"CUDA": multi_head_jagged_flash_attention,
"AutogradCUDA": multi_head_jagged_flash_attention,
},
"sll_jagged_self_substraction_jagged_out": {
"CUDA": triton_jagged_self_substraction_jagged_out,
},
"sll_jagged_dense_elementwise_mul_jagged_out": {
"CUDA": jagged_dense_elementwise_mul_jagged_out,
"AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
},
}
22 changes: 22 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/common.py
@@ -9,6 +9,28 @@
import torch


def next_power_of_two(N: int) -> int:
if N > 4096:
raise Exception(f"{N} is too large that is not supported yet")

if N > 2048:
return 4096
elif N > 1024:
return 2048
elif N > 512:
return 1024
elif N > 256:
return 512
elif N > 128:
return 256
elif N > 64:
return 128
elif N > 32:
return 64
else:
return 32


def expect_contiguous(x: torch.Tensor) -> torch.Tensor:
if not x.is_contiguous():
return x.contiguous()
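
For reference, the next_power_of_two helper now hosted in common.py rounds its argument up to a power of two between 32 and 4096. An equivalent, more compact formulation (illustrative only; not the code added in this PR):

def next_power_of_two_compact(N: int) -> int:
    # Same behavior as the table-based helper for N <= 4096:
    # round N up to a power of two, with a floor of 32.
    if N > 4096:
        raise ValueError(f"{N} is larger than the supported maximum of 4096")
    return max(32, 1 << max(0, N - 1).bit_length())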
@@ -11,61 +11,6 @@
import triton.language as tl


def next_power_of_two(N: int) -> int:
if N > 4096:
raise Exception(f"{N} is too large that is not supported yet")

if N > 2048:
return 4096
elif N > 1024:
return 2048
elif N > 512:
return 1024
elif N > 256:
return 512
elif N > 128:
return 256
elif N > 64:
return 128
elif N > 32:
return 64
else:
return 32


@triton.jit
def jagged_self_substraction_jagged_out_kernel(
a_ptr, # jagged
b_ptr, # jagged
a_offsets_ptr,
b_offsets_ptr,
max_seq_len,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
pid_index = tl.program_id(1)

a_offset = tl.load(a_offsets_ptr + pid_batch)
a_length = tl.load(a_offsets_ptr + pid_batch + 1) - a_offset
a_length = tl.minimum(a_length, max_seq_len + 1)

if a_length <= 1:
return

N = a_length - 1
if pid_index >= N:
return

a_cur = tl.load(a_ptr + a_offset + pid_index)
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < N
a_row = tl.load(a_ptr + a_offset + offs + 1, mask=mask)
b = a_cur - a_row

b_offset = tl.load(b_offsets_ptr + pid_batch)
tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask)


@triton.jit
def jagged_dense_elementwise_mul_jagged_out_kernel(
a_ptr, # 1d jagged
@@ -123,33 +68,6 @@ def jagged_dense_elementwise_mul_jagged_out_kernel(
c_ptrs += BLOCK_N


def triton_jagged_self_substraction_jagged_out(
jagged_A: torch.Tensor,
offsets_a: torch.Tensor,
offsets_b: torch.Tensor,
max_seq_len,
) -> torch.Tensor:
B = offsets_a.size(0) - 1

jagged_B = torch.empty(
(int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype
)

BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16)
grid = (B, max_seq_len)

jagged_self_substraction_jagged_out_kernel[grid](
jagged_A,
jagged_B,
offsets_a,
offsets_b,
max_seq_len,
BLOCK_SIZE, # pyre-fixme[6]: For 6th argument expected `constexpr` but got `int`.
)

return jagged_B


def triton_jagged_dense_elementwise_mul_jagged_out(
jagged_A,
dense_B,
73 changes: 73 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import triton
import triton.language as tl

from .common import next_power_of_two


@triton.jit
def jagged_self_substraction_jagged_out_kernel(
a_ptr, # jagged
b_ptr, # jagged
a_offsets_ptr,
b_offsets_ptr,
max_seq_len,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
pid_index = tl.program_id(1)

a_offset = tl.load(a_offsets_ptr + pid_batch)
a_length = tl.load(a_offsets_ptr + pid_batch + 1) - a_offset
a_length = tl.minimum(a_length, max_seq_len + 1)

if a_length <= 1:
return

N = a_length - 1
if pid_index >= N:
return

a_cur = tl.load(a_ptr + a_offset + pid_index)
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < N
a_row = tl.load(a_ptr + a_offset + offs + 1, mask=mask)
b = a_cur - a_row

b_offset = tl.load(b_offsets_ptr + pid_batch)
tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask)


def triton_jagged_self_substraction_jagged_out(
jagged_A: torch.Tensor,
offsets_a: torch.Tensor,
offsets_b: torch.Tensor,
max_seq_len,
) -> torch.Tensor:
B = offsets_a.size(0) - 1

jagged_B = torch.empty(
(int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype
)

BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16)
grid = (B, max_seq_len)

jagged_self_substraction_jagged_out_kernel[grid](
jagged_A,
jagged_B,
offsets_a,
offsets_b,
max_seq_len,
BLOCK_SIZE,
)

return jagged_B
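
As a rough usage sketch of the relocated wrapper (assuming a CUDA device; the toy sizes and the offset construction are illustrative, not taken from this PR), each batch of effective length n contributes an (n - 1) x (n - 1) block of pairwise differences a[i] - a[j + 1]:

import torch

from fbgemm_gpu.sll.triton import triton_jagged_self_substraction_jagged_out

device = "cuda"
max_seq_len = 4
lengths = torch.tensor([4, 3], device=device)

# Offsets of the input jagged tensor (prefix sums of per-batch lengths).
offsets_a = torch.cat(
    [torch.zeros(1, dtype=torch.int64, device=device), lengths.cumsum(0)]
)
jagged_A = torch.arange(int(offsets_a[-1]), dtype=torch.float32, device=device)

# Each batch of effective length n (capped at max_seq_len + 1) emits an
# (n - 1) ** 2 block, so offsets_b are prefix sums of those block sizes.
n = (lengths.clamp(max=max_seq_len + 1) - 1).clamp(min=0)
offsets_b = torch.cat(
    [torch.zeros(1, dtype=torch.int64, device=device), (n * n).cumsum(0)]
)

out = triton_jagged_self_substraction_jagged_out(
    jagged_A, offsets_a, offsets_b, max_seq_len
)
# For batch 0 (length 4, so N = 3): out[i * 3 + j] == jagged_A[i] - jagged_A[j + 1]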
10 changes: 6 additions & 4 deletions fbgemm_gpu/test/sll/array_jagged_bmm_jagged_out_test.py
@@ -10,7 +10,6 @@

import hypothesis.strategies as st
import torch
from fbgemm_gpu.sll.triton import triton_array_jagged_bmm_jagged_out
from hypothesis import given, settings

from .common import open_source
@@ -21,6 +20,9 @@
else:
from fbgemm_gpu.test.test_utils import gpu_unavailable, running_on_rocm

if torch.cuda.is_available():
from fbgemm_gpu.sll.triton import triton_array_jagged_bmm_jagged_out


class ArrayJaggedBmmJaggedTest(unittest.TestCase):
# pyre-fixme[56]: Pyre was not able to infer the type of argument
Expand All @@ -31,7 +33,7 @@ class ArrayJaggedBmmJaggedTest(unittest.TestCase):
)
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
@settings(deadline=20000)
@settings(deadline=30000)
def test_triton_array_jagged_bmm_jagged_out(
self,
B: int,
@@ -157,7 +159,7 @@ def ref_array_jagged_bmm_jagged_out(
)
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
@settings(deadline=20000)
@settings(deadline=30000)
def test_triton_array_jagged_bmm_jagged_out_with_grad(
self,
B: int,
@@ -244,7 +246,7 @@ def test_triton_array_jagged_bmm_jagged_out_with_grad(
)
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
@settings(deadline=20000)
@settings(deadline=30000)
def test_triton_array_jagged_bmm_jagged_out_meta_backend(
self,
B: int,
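
The guarded import above keeps the test module importable on CPU-only hosts; the GPU-only symbol is only touched inside tests that are already skipped via gpu_unavailable. A minimal sketch of that pattern (the test body is hypothetical):

import unittest

import torch

from fbgemm_gpu.test.test_utils import gpu_unavailable

if torch.cuda.is_available():
    from fbgemm_gpu.sll.triton import triton_array_jagged_bmm_jagged_out


class ImportGuardExampleTest(unittest.TestCase):
    @unittest.skipIf(*gpu_unavailable)
    def test_symbol_available_on_gpu(self) -> None:
        # Only runs when CUDA is present, so the guarded import above has executed.
        self.assertTrue(callable(triton_array_jagged_bmm_jagged_out))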
3 changes: 1 addition & 2 deletions fbgemm_gpu/test/sll/common.py
@@ -8,8 +8,7 @@
# pyre-ignore-all-errors[56]

import fbgemm_gpu
import fbgemm_gpu.sll.cpu_sll
import fbgemm_gpu.sll.triton_sll
import fbgemm_gpu.sll
import torch

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.