Commit 460483d — Update tutorials code in NeuronSDK 2.21 release
1 parent 9919484

7 files changed: +293 −3 lines

src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py (+3 −3)

@@ -15,7 +15,7 @@
 # NKI_EXAMPLE_31_BEGIN
 @nki.jit
 def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
-                                           mixed_percision=True):
+                                           mixed_precision=True):
   """
   Fused self attention kernel for small head dimension Stable Diffusion workload,
   simplified for this tutorial.
@@ -38,14 +38,14 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=
 
   IO tensor dtypes:
    - This kernel assumes all IO tensors have the same dtype
-   - If mixed_percision is True, then all Tensor Engine operation will be performed in
+   - If mixed_precision is True, then all Tensor Engine operation will be performed in
     bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
     will be in the same type as the inputs.
   """
   # Use q_ref dtype as the intermediate tensor dtype
   # Assume all IO tensors have the same dtype
   kernel_dtype = q_ref.dtype
-  pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
+  pe_in_dt = nl.bfloat16 if mixed_precision else np.float32
   assert q_ref.dtype == k_ref.dtype == v_ref.dtype
 
   # Shape checking
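
The rename is API-visible: any call site that passed the flag by keyword must switch to the new spelling. A minimal calling sketch (the q/k/v shapes are illustrative, not taken from the commit; the kernel is invoked directly on host arrays the same way the tensor-addition kernels below are):

import numpy as np
from sd_attention_nki_kernels import fused_self_attn_for_SD_small_head_size

# Illustrative [seqlen, d_head] inputs; the kernel asserts q, k, v share a dtype.
q = np.random.rand(4096, 64).astype(np.float32)
k = np.random.rand(4096, 64).astype(np.float32)
v = np.random.rand(4096, 64).astype(np.float32)

# Before this commit the keyword was spelled `mixed_percision`; that name is
# no longer accepted. Use the corrected keyword:
out = fused_self_attn_for_SD_small_head_size(q, k, v, mixed_precision=True)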

New file (+34 lines): JAX driver for SPMD tensor addition with multiple Neuron cores

"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

JAX implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.

"""
# NKI_EXAMPLE_50_BEGIN
import jax
import jax.numpy as jnp
# NKI_EXAMPLE_50_END

from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2

# NKI_EXAMPLE_50_BEGIN
if __name__ == "__main__":

  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
  a = jax.random.uniform(seed_a, (512, 2048), dtype=jnp.bfloat16)
  b = jax.random.uniform(seed_b, (512, 2048), dtype=jnp.bfloat16)

  output_nki = nki_tensor_add_nc2(a, b)
  print(f"output_nki={output_nki}")

  output_jax = a + b
  print(f"output_jax={output_jax}")

  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

  assert allclose
# NKI_EXAMPLE_50_END
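
One thing this driver does not check: the caller below computes its grid with floor division, so inputs whose dimensions are not multiples of 256 (dim 0) and 512 (dim 1) would silently leave trailing rows/columns unprocessed. A guard one might add in a driver like this (a hypothetical helper, not part of the commit):

def check_nc2_shape(t):
  # Dim 0 must tile into 128-row tiles across 2 cores, dim 1 into 512-column
  # tiles; otherwise nki_tensor_add_nc2's floor division skips the remainder.
  assert t.shape[0] % (128 * 2) == 0, f"dim 0 ({t.shape[0]}) not a multiple of 256"
  assert t.shape[1] % 512 == 0, f"dim 1 ({t.shape[1]}) not a multiple of 512"

check_nc2_shape(a)  # (512, 2048) passes: 512 = 2 * 256, 2048 = 4 * 512
check_nc2_shape(b)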

New file (+61 lines): spmd_multiple_nc_tensor_addition_nki_kernels.py (filename inferred from the driver imports)

"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

NKI implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.

"""
import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
from spmd_tensor_addition_nki_kernels import nki_tensor_add_kernel_


# NKI_EXAMPLE_48_BEGIN
def nki_tensor_add_nc2(a_input, b_input):
  """NKI kernel caller to compute element-wise addition of two input tensors using multiple Neuron cores.

  This kernel caller lifts the tile-size restriction by applying the kernel on tiles of the inputs/outputs.
  a_input and b_input are sharded across Neuron cores, directly utilizing Trn2 architecture capabilities.

  Args:
    a_input: a first input tensor, of shape [N*128, M*512]
    b_input: a second input tensor, of shape [N*128, M*512]

  Returns:
    a tensor of shape [N*128, M*512], the result of a_input + b_input
  """

  # The SPMD launch grid denotes the number of kernel instances.
  # In this case, we use a 2D grid where the size of each invocation is 128x512.
  # Since we're sharding across Neuron cores on the 1st dimension, we slice at
  # 128 per core * 2 cores = 256.
  grid_x = a_input.shape[0] // (128 * 2)
  grid_y = a_input.shape[1] // 512

  # In addition, we distribute the kernel to physical Neuron cores along the
  # first dimension of the SPMD grid. This means:
  # Physical NC [0]: kernel[n, m] where n is even
  # Physical NC [1]: kernel[n, m] where n is odd
  # Notice that by specifying this information in the SPMD grid, we can use
  # multiple Neuron cores without updating the original `nki_tensor_add_kernel_` kernel.
  return nki_tensor_add_kernel_[nl.spmd_dim(grid_x, nl.nc(2)), grid_y](a_input, b_input)
# NKI_EXAMPLE_48_END

if __name__ == "__main__":
  a = np.random.rand(512, 2048).astype(np.float16)
  b = np.random.rand(512, 2048).astype(np.float16)

  output_nki = nki_tensor_add_nc2(a, b)
  print(f"output_nki={output_nki}")

  output_np = a + b
  print(f"output_np={output_np}")

  allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and NumPy match")
  else:
    print("NKI and NumPy differ")

  assert allclose
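
As a back-of-the-envelope check of the launch grid, following the comments in the caller above, the (512, 2048) inputs in the __main__ block decompose as:

# Pure-Python arithmetic mirroring nki_tensor_add_nc2's grid computation:
grid_x = 512 // (128 * 2)   # = 2 steps along dim 0 per physical core
grid_y = 2048 // 512        # = 4 column tiles along dim 1
# nl.spmd_dim(grid_x, nl.nc(2)) expands the first grid dimension to
# grid_x * 2 = 4 instances, alternating between the two physical Neuron cores
# (even n -> NC 0, odd n -> NC 1), so 4 * 4 = 16 instances each own one
# 128x512 tile of the 512x2048 tensors.
print(grid_x, grid_y)  # 2 4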

New file (+35 lines): PyTorch driver for SPMD tensor addition with multiple Neuron cores

"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

PyTorch implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.

"""
# NKI_EXAMPLE_49_BEGIN
import torch
from torch_xla.core import xla_model as xm
# NKI_EXAMPLE_49_END

from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2


# NKI_EXAMPLE_49_BEGIN
if __name__ == "__main__":
  device = xm.xla_device()

  a = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device)
  b = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device)

  output_nki = nki_tensor_add_nc2(a, b)
  print(f"output_nki={output_nki}")

  output_torch = a + b
  print(f"output_torch={output_torch}")

  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose
# NKI_EXAMPLE_49_END
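
torch_xla tensors are lazy, so the prints above are what force execution, and the comparison runs on device. If you prefer to inspect the result host-side, a hedged variant using only standard torch/torch_xla calls (not part of the commit) is:

# Move results to host and widen bfloat16 to float32 before comparing on CPU.
output_cpu = output_nki.to("cpu").to(torch.float32)
expected_cpu = (a + b).to("cpu").to(torch.float32)
print(torch.allclose(expected_cpu, output_cpu, atol=1e-4, rtol=1e-2))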

New file (+34 lines): JAX driver for SPMD tensor addition

"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

JAX implementation for SPMD tensor addition NKI tutorial.

"""
# NKI_EXAMPLE_30_BEGIN
import jax
import jax.numpy as jnp
# NKI_EXAMPLE_30_END

from spmd_tensor_addition_nki_kernels import nki_tensor_add

# NKI_EXAMPLE_30_BEGIN
if __name__ == "__main__":

  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
  a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16)
  b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_jax = a + b
  print(f"output_jax={output_jax}")

  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

  assert allclose
# NKI_EXAMPLE_30_END

New file (+91 lines): spmd_tensor_addition_nki_kernels.py (filename inferred from the driver imports)

"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

NKI implementation for SPMD tensor addition NKI tutorial.

"""
import numpy as np
# NKI_EXAMPLE_27_BEGIN
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_tensor_add_kernel_(a_input, b_input):
  """NKI kernel to compute element-wise addition of two input tensors.

  This kernel assumes the input/output sizes can be uniformly tiled into [128, 512] tiles.

  Args:
    a_input: a first input tensor
    b_input: a second input tensor

  Returns:
    c_output: an output tensor
  """
  # Create the result tensor in shared HBM, visible to all SPMD instances
  c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

  # Calculate tile offsets based on the current 'program' (SPMD instance)
  offset_i_x = nl.program_id(0) * 128
  offset_i_y = nl.program_id(1) * 512

  # Generate tensor indices to index tensors a and b
  ix = offset_i_x + nl.arange(128)[:, None]
  iy = offset_i_y + nl.arange(512)[None, :]

  # Load input data from device memory (HBM) to on-chip memory (SBUF)
  # We refer to an indexed portion of a tensor as an intermediate tensor
  a_tile = nl.load(a_input[ix, iy])
  b_tile = nl.load(b_input[ix, iy])

  # Compute a + b
  c_tile = a_tile + b_tile

  # Store the addition results back to device memory (c_output)
  nl.store(c_output[ix, iy], value=c_tile)

  # Transfer the ownership of `c_output` to the caller
  return c_output
# NKI_EXAMPLE_27_END


# NKI_EXAMPLE_28_BEGIN
def nki_tensor_add(a_input, b_input):
  """NKI kernel caller to compute element-wise addition of two input tensors.

  This kernel caller lifts the tile-size restriction by applying the kernel on tiles of the inputs/outputs.

  Args:
    a_input: a first input tensor, of shape [N*128, M*512]
    b_input: a second input tensor, of shape [N*128, M*512]

  Returns:
    a tensor of shape [N*128, M*512], the result of a_input + b_input
  """

  # The SPMD launch grid denotes the number of kernel instances.
  # In this case, we use a 2D grid where the size of each invocation is 128x512.
  grid_x = a_input.shape[0] // 128
  grid_y = a_input.shape[1] // 512

  return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
# NKI_EXAMPLE_28_END

if __name__ == "__main__":
  a = np.random.rand(256, 1024).astype(np.float16)
  b = np.random.rand(256, 1024).astype(np.float16)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_np = a + b
  print(f"output_np={output_np}")

  allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and NumPy match")
  else:
    print("NKI and NumPy differ")

  assert allclose
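
The index arithmetic in nki_tensor_add_kernel_ is easiest to sanity-check on the host. A NumPy sketch of what program instance (n, m) loads, mirroring the nl.program_id/nl.arange expressions above (the grid position is a hypothetical example):

import numpy as np

a = np.random.rand(256, 1024).astype(np.float16)  # the shape used in __main__

n, m = 1, 1                              # hypothetical grid position
ix = n * 128 + np.arange(128)[:, None]   # rows 128..255, like offset_i_x + nl.arange(128)
iy = m * 512 + np.arange(512)[None, :]   # cols 512..1023, like offset_i_y + nl.arange(512)

tile = a[ix, iy]                         # the 128x512 tile this instance would load
assert tile.shape == (128, 512)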

New file (+35 lines): PyTorch driver for SPMD tensor addition

"""
Copyright (C) 2024, Amazon.com. All Rights Reserved

PyTorch implementation for SPMD tensor addition NKI tutorial.

"""
# NKI_EXAMPLE_29_BEGIN
import torch
from torch_xla.core import xla_model as xm
# NKI_EXAMPLE_29_END

from spmd_tensor_addition_nki_kernels import nki_tensor_add


# NKI_EXAMPLE_29_BEGIN
if __name__ == "__main__":
  device = xm.xla_device()

  a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
  b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_torch = a + b
  print(f"output_torch={output_torch}")

  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose
# NKI_EXAMPLE_29_END
