nki-samples updates for NeuronSDK release 2.21 #38

Merged
merged 3 commits on Dec 23, 2024
Changes from 1 commit
NeuronSDK 2.21: NKI-Samples Update (#6)
Update nki-samples for NeuronSDK 2.21 Beta Release.
1. Use new nki.jit decorator for kernels.
2. Added samples for the new direct allocation feature.
3. Misc tests and documentation improvements.

Co-authored-by: aws-qieqingy <[email protected]>
neuron-code-sharing-robot and aws-qieqingy authored Dec 13, 2024
commit a33dfbc883a80f3875ec534a6bb5e7d3b159543e
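Item 2 of the commit message refers to NKI's direct allocation feature, exercised by the new `allocated_*` kernels in this diff: on-chip buffers are placed at explicit SBUF/PSUM offsets via `ncc.sbuf.mod_alloc` and `ncc.psum.mod_alloc` instead of being placed automatically by the compiler. The following is a minimal sketch of the idea, not code from this PR; the kernel name, shapes, dtypes, and byte offsets are illustrative assumptions, and real allocated kernels (such as `allocated_fused_rms_norm_qkv` below) manage offsets far more carefully.

```python
from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.compiler as ncc
from neuronxcc.nki.language import par_dim

@nki.jit
def scaled_copy_allocated(x):
    # x is assumed to be a [128, 512] fp32 tensor in device memory (HBM).
    # Output lives in shared HBM and is returned (the new nki.jit convention).
    out = nl.ndarray(x.shape, dtype=x.dtype, buffer=nl.shared_hbm)

    # Stage the input tile in SBUF at an explicit per-partition byte offset.
    x_sb = nl.ndarray((par_dim(128), 512), dtype=nl.float32,
                      buffer=ncc.sbuf.mod_alloc(base_addr=0))
    x_sb[...] = nl.load(x)

    # Result tile placed right after the input tile (512 elements * 4 bytes).
    y_sb = nl.ndarray((par_dim(128), 512), dtype=nl.float32,
                      buffer=ncc.sbuf.mod_alloc(base_addr=512 * 4))
    y_sb[...] = nl.multiply(x_sb, 2.0)

    nl.store(out, value=y_sb)
    return out
```

Item 1, the switch to the `nki.jit` decorator, is the other recurring change in this diff: kernels no longer take an output argument; they allocate the output in `nl.shared_hbm`, return it, and call sites receive the result as a return value.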
69 changes: 7 additions & 62 deletions CONTRIBUTING.md
@@ -1,6 +1,6 @@
# Contributing Guidelines

Thank you for your interest in contributing to our project. Whether it's a new NKI kernel, improving existing kernel code, bug fix, new feature, correction, or additional
Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.

Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
@@ -24,13 +24,14 @@ reported the issue. Please try to include as much information as you can. Detail
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:

1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged pull requests to make sure someone else hasn't addressed the problem already.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.

To send us a pull request, please:

1. Fork the repository.
2. Modify the source; please focus on the specific changes you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Please ensure your change satisfies the requirements listed in [Testing Requirements](#testing-requirements) and [Coding Guidelines](#coding-guidelines)
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Wait for a repository collaborator to look at your pull request, run the automated tests, and review. If additional changes or discussion is needed, a collaborator will get back to you, so please stay involved in the conversation.
@@ -39,64 +40,8 @@ To send us a pull request, please:
GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).

### Testing Requirements
Running the binaries for an NKI kernel requires Neuron devices on an AWS EC2 instance from the trn1, trn1n, or inf2 instance families.
Details on setting up an instance can be found [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html).

If you would like to test your kernel without requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` input/output tensors and types.
An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py). However, kernels with _only_ simulation tests will not be accepted.

#### Requirements for Kernels Targeting `src/reference/`

All kernels located in this folder need to have the following tests.

1. Numeric accuracy tests with `nki.baremetal`. The output from the kernel
must be validated against a CPU reference implementation. See `test_flash_attn_fwd_numerical` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.baremetal` is available [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.baremetal.html).

2. Performance benchmark tests with `nki.benchmark`. The unit test must have performance checks. At a minimum, put an assertion to verify that p99 latency meets a certain threshold. See `test_flash_attn_fwd_perf` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.benchmark` is available [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.benchmark.html). A minimal sketch of both test types follows this list.

3. End-to-End integration tests that use your kernel in a model.

a. Each test should be in its own separate folder.

b. Each test must have a `run.sh` script that accepts an argument \<path to test_result.json\>. See [run.sh of FlashAttention](test/integration/flash_attention/run.sh) as an example.

c. The test scripts must produce benchmark results with the `benchmark` function, located in [LatencyCollector.py](test/integration/perf_utils/LatencyCollector.py). The `benchmark` function will write the latency of your E2E model to the `test_result.json`.

d. Register your test target in [run_integration.sh](test/integration/run_integration.sh).
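
The following is a rough sketch of the two test types named in items 1 and 2 above, for a hypothetical `@nki.jit` kernel `my_kernel(a, b)` that adds two tiles. Shapes, tolerances, and the latency budget are placeholders rather than project requirements; the direct NumPy call stands in for the `nki.baremetal` flow using the simplified calling convention adopted in this release, and the tests referenced above remain the authoritative examples.

```python
import numpy as np
from neuronxcc import nki

from my_kernels import my_kernel  # hypothetical @nki.jit kernel under test


def test_my_kernel_numerical():
    a = np.random.rand(128, 512).astype(np.float32)
    b = np.random.rand(128, 512).astype(np.float32)
    # Calling a @nki.jit kernel with NumPy inputs runs it on a Neuron device
    # and returns the output tensor (see the updated tutorials in this PR).
    out = my_kernel(a, b)
    # Validate against a CPU reference implementation.
    np.testing.assert_allclose(out, a + b, rtol=1e-5)


def test_my_kernel_perf():
    a = np.random.rand(128, 512).astype(np.float32)
    b = np.random.rand(128, 512).astype(np.float32)
    bench = nki.benchmark(warmup=5, iters=10)(my_kernel)
    bench(a, b)
    # p99 latency is reported in microseconds; the 500 us budget is a placeholder.
    p99 = bench.benchmark_result.nc_latency.get_latency_percentile(99)
    assert p99 <= 500
```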


### Coding Guidelines
Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions.
In addition to PEP-8, we use the following NKI specific style guidelines:

1. **Abbreviations**
* Importing NKI modules should use consistent names. For example,
```
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt
import numpy as np
```
2. Variable Names
* Indexing should specify partition and free dimensions along with the variable they are used for. For example:
The index for the partition dimension for tile `a` would be
```
i_p_a = nl.arange(128)[:, None]
```
while the index for the free dimension for tile `b` would be
```
i_f_b = nl.arange(512)[None, :]
```
* Name loop variables, indices, and buffers consistently, and specify their intended use in the name.

3. Documentation
* New kernels should contain inline docstrings that describe the semantics of the kernel and provide information on the IO layout.
Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation.


## Finding Contributions to Work on
## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws-neuron/nki-samples/labels/help%20wanted) issues is a great place to start.


@@ -106,7 +51,7 @@ For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of
[email protected] with any additional questions or comments.


## Security Issue Notifications
## Security issue notifications
If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.


16 changes: 0 additions & 16 deletions LICENSE

This file was deleted.

1 change: 1 addition & 0 deletions LICENSE.txt
@@ -0,0 +1 @@
TODO: Fill LICENSE after it is finalized
93 changes: 74 additions & 19 deletions README.md
@@ -6,7 +6,7 @@ At the core of the Neuron SDK is the Neuron Compiler, which takes computation gr
them into highly optimized machine code.

[NKI](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki) is a Python-based programming environment designed for the compiler which
adopts commonly used NumPy and Triton-like syntax along with tile-level semantics.
adopts commonly used NumPy and Triton-like syntax along with tile-level semantics.
NKI also interoperates with the Neuron Profiler, providing insights into performance bottlenecks and instruction latencies.
It offers tensor printing support, standard error messaging, and built-in kernel simulation capabilities for efficient debugging purposes.
NKI offers two types of programming interfaces:
@@ -16,25 +16,31 @@ enabling bare-metal access to the chip for full control.

![alt "High-level flow of NKI in the Neuron Compiler. NKI emits IR immediately before the backend-IR compilation stage"](doc_assets/high-level-nki-flow.png#center "High-Level NKI Flow")

## Documentation
The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/).
Documentation for NKI kernels is both inline (docstring) and available on the documentation site's
[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html).
### nki.language
**nki.language** enables precise control over computation and data movement on NeuronCores, the processing units within AWS Inferentia and Trainium chips.
Developers can control data movement between device memory and on-chip memory explicitly using `nl.load()` and `nl.store()` operations.
Developers can then perform the desired computations on loaded tensors, such as element-wise operations or tensor contractions,
providing crucial performance improvements. Additionally, developers can control how computation is performed on different compute engines inside NeuronCores.
nki.language APIs are considered high-level APIs and are designed for "ease of use" for ML practitioners.
To achieve the best performance, developers can enlist the nki.isa APIs.
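
As an illustrative sketch of that flow (a hypothetical kernel rather than one of the samples here; a single 128 x 512 tile is assumed so no tiling loop is needed), the explicit load, compute, and store steps look like this:

```python
from neuronxcc import nki
import neuronxcc.nki.language as nl

@nki.jit
def scale_and_add_kernel(a, b):
    # Output tensor in device memory (HBM), returned to the caller.
    out = nl.ndarray(a.shape, dtype=a.dtype, buffer=nl.shared_hbm)

    # Explicitly move data from device memory into on-chip memory.
    a_tile = nl.load(a)
    b_tile = nl.load(b)

    # Element-wise compute on the loaded tiles.
    c_tile = nl.add(nl.multiply(a_tile, 0.5), b_tile)

    # Explicitly move the result back to device memory.
    nl.store(out, value=c_tile)
    return out
```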

![alt "Diagram of the NeuronCore Architecture. It shows 4 engines: tensor, vector, scalar, and GPSIMD, connected to SBUF memory. The tensor, vector, and scalar engines are also connected to a high-speed PSUM memory bank that supports accumulate on write. Lastly the HBM (DRAM) is connected to both SBUF and PSUM memory banks."](doc_assets/pm-nc.png#scale_50#center "NeuronCore Architecture")

### nki.isa

**nki.isa** provides direct access to chip instructions to offer flexibility and fine-grained control over instruction usage and performance optimizations.
Developers can utilize various `nki.isa` instructions that run on the Tensor, Vector, Scalar, GP-SIMD, and DMA engines.
For example, developers can use `nki.isa.nc_matmul()` to compute a matrix multiplication on the Tensor Engine.
Alternatively, developers can use `nki.isa.activation()` to apply an activation function to every element of the input tile on the Scalar Engine.
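
For illustration only (not a kernel from this repository; 128 x 128 and 128 x 512 inputs are assumed so each fits a single Tensor Engine tile, and default arguments are assumed where the `nki.isa` calls omit them), combining the two instructions mentioned above might look like:

```python
from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa

@nki.jit
def matmul_activation_kernel(lhsT, rhs):
    # lhsT: [128, 128] with the contraction dim on partitions; rhs: [128, 512].
    out = nl.ndarray((128, 512), dtype=nl.float32, buffer=nl.shared_hbm)

    lhsT_tile = nl.load(lhsT)
    rhs_tile = nl.load(rhs)

    # Tensor Engine: matrix multiplication, accumulated in PSUM.
    prod = nisa.nc_matmul(lhsT_tile, rhs_tile)

    # Scalar Engine: element-wise activation applied to the product.
    act = nisa.activation(op=nl.exp, data=prod)

    nl.store(out, value=act)
    return out
```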

## Repository Structure

### src

#### reference
This folder contains the source code of the `neuronxcc.nki.kernels` package; these are optimized kernels from the Neuron team, serving as samples.

All kernels located in this folder have numeric accuracy tests
The [reference kernels](src/reference/) are optimized kernels from the Neuron team. All kernels located in this folder must have all of the numeric accuracy tests
and performance benchmarks defined in the [test](test/) directory. We also demonstrate using these kernels end-to-end in our [integration tests](test/integration/).

Note that these kernels are already being deployed as part of the Neuron stack. With flash attention as an example,
[compiling Llama models with transformers-neuronx](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/transformers-neuronx-developer-guide.html)
will automatically invoke the `flash_fwd` kernel in [attention.py](src/reference/attention.py). Therefore, replacing the framework operators with these NKI kernels likely won't result in extra performance benefit.


#### tutorials
The [tutorial kernels](src/tutorials/) are for educational purposes and include the kernels that are used in the NKI guides.
@@ -52,16 +58,65 @@ verify the numeric accuracy of the operation, and publish performance results to
The [integration tests](tests/integration) folder contains integration tests of (selected) kernels. They verify the numeric accuracy of the model’s output,
and publish end-to-end performance results into the [integration benchmarks](docs/benchmarks/integration) folder.

## Maintenance Policy
NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. The NKI APIs follow the [Neuron SDK Maintenance Policy](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/sdk-policy.html).
## Documentation
The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/).
Documentation for NKI kernels is both inline (docstring) and available on the documentation site's
[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html).

## Getting Help
Have a look at the GitHub issues for this repository, where you will find past issues customers have encountered along with workarounds and clarifications.
If you cannot find a suitable issue for your use-case, feel free to [file an issue](https://github.com/aws-neuron/nki-samples/issues/new) to ask for assistance or to suggest improvements. Please read [CONTRIBUTING.md](CONTRIBUTING.md) for detailed information on submitting issues.
## Versioning
NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. We will also be updating the NKI API as needed
to support new Neuron and Neuron Compiler features. While NKI is in beta, we may need to make backwards-incompatible changes to incorporate feedback from
our users or to support new use-cases of NKI on Neuron devices. Upon releasing NKI as generally available (GA), we will commit to not making
backwards-incompatible changes to the NKI API for any supported version of the Neuron compiler.

## Contributing
We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via
GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more information
We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via
GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements.

### Getting Help
Have a look at the GitHub issues for this repository, where you will find past issues customers have encountered along with workarounds and clarifications.
If you cannot find a suitable issue for your use-case, feel free to file an issue asking for assistance or to suggest improvements.

In addition, extensive NKI documentation can be found [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki).

### Testing and Merging
Running the binaries for an NKI kernel requires Neuron devices on an AWS EC2 instance from the trn1, trn1n, or inf2 instance families.
Details on setting up an instance can be found [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html).

Before merging, the Neuron team will need to internally test and verify that kernels work as expected. If the change is accepted,
we will merge your changes manually, and they will appear in this repository with the next release.

If you would like to test your kernel without requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` tensors and types.
An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py).
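
A rough sketch of what such a simulation-only check could look like follows; the exact `nki.simulate()` entry point and signature should be taken from the linked layernorm tutorial test, and the kernel, shapes, and wrapper-style call below are assumptions made for illustration.

```python
import numpy as np
import neuronxcc.nki as nki

from my_kernels import my_kernel  # hypothetical @nki.jit kernel under test


def test_my_kernel_simulated():
    a = np.random.rand(128, 512).astype(np.float32)
    b = np.random.rand(128, 512).astype(np.float32)

    # Run the kernel in simulation with NumPy tensors; no Neuron device needed.
    # Assumes nki.simulate() wraps a kernel the same way nki.baremetal does.
    out = nki.simulate(my_kernel)(a, b)

    np.testing.assert_allclose(out, a + b, rtol=1e-5)
```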

### Coding Guidelines
Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions.
In addition to PEP-8, we use the following NKI specific style guidelines:

1. **Abbreviations**
* Importing NKI modules should use consistent names. For example,
```
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt
import numpy as np
```
2. Variable Names
* Indexing should specify partition and free dimensions along with the variable they are used for. For example:
The index for the partition dimension for tile `a` would be
```
i_p_a = nl.arange(128)[:, None]
```
while the index for the free dimension for tile `b` would be
```
i_f_b = nl.arange(512)[None, :]
```
* Name loop variables, indices, and buffers consistently, and specify their intended use in the name.

3. Documentation
* New kernels should contain inline docstrings that describe the semantics of the kernel and provide information on the IO layout.
Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation.

## Licensing
This repository is licensed under the terms of the [MIT-0 License](LICENSE.txt).
22 changes: 21 additions & 1 deletion src/reference/__init__.py
@@ -6,7 +6,27 @@
Kernels here are the same as the ones available in the
NKI GitHub Sample Repo.

TODO: Insert link to Github Repo when available
https://github.com/aws-neuron/nki-samples
"""
from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size, flash_attn_bwd, flash_fwd
from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel, select_and_scatter_kernel
from neuronxcc.nki.kernels.tutorial import add_kernel_nx8x128x512
from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size
from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv

from neuronxcc.nki._private_kernels.legacy.attention import \
(fused_self_attn_for_SD_small_head_size as _fused_self_attn_for_SD_small_head_size,
flash_attn_bwd as _flash_attn_bwd, flash_fwd as _flash_fwd)
from neuronxcc.nki._private_kernels.legacy.vision import (
resize_nearest_fixed_dma_kernel as _resize_nearest_fixed_dma_kernel,
select_and_scatter_kernel as _select_and_scatter_kernel)
from neuronxcc.nki._private_kernels.legacy.tutorial import add_kernel_nx8x128x512 as _add_kernel_nx8x128x512
from neuronxcc.nki._private_kernels.legacy.allocated_fused_linear import _allocated_fused_rms_norm_qkv

fused_self_attn_for_SD_small_head_size._legacy_func = _fused_self_attn_for_SD_small_head_size
flash_attn_bwd._legacy_func = _flash_attn_bwd
flash_fwd._legacy_func = _flash_fwd
resize_nearest_fixed_dma_kernel._legacy_func = _resize_nearest_fixed_dma_kernel
select_and_scatter_kernel._legacy_func = _select_and_scatter_kernel
add_kernel_nx8x128x512._legacy_func = _add_kernel_nx8x128x512
allocated_fused_rms_norm_qkv._legacy_func = _allocated_fused_rms_norm_qkv
283 changes: 283 additions & 0 deletions src/reference/allocated_attention.py

Large diffs are not rendered by default.

114 changes: 114 additions & 0 deletions src/reference/allocated_fused_linear.py
@@ -0,0 +1,114 @@
"""
Copyright (c) 2024, Amazon.com. All Rights Reserved
kernels - Fused normalization with linear layers
"""

import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.compiler as ncc
import math
import numpy as np
from neuronxcc import nki
from neuronxcc.nki.language import par_dim

@nki.jit
def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-6):
"""
Allocated kernel that computes RMSNorm(hidden) @ wQKV. This kernel is designed to only handle fp16/bf16 tensor types.
Internally, normalizations are cast to fp32 to avoid NaN errors.
Args:
hidden (_type_): Input tensor of the attention block in BSH layout
weights (_type_): Fused QKV linear weights, assumed to be eltwise-multiplied with RMS norm weight vector (gamma)
norm_dtype (_type_, optional): Data type for RMS norm, should be f32 to avoid NaN. Defaults to nl.float32.
eps (_type_, optional): RMS norm epsilon term. Defaults to 1e-6.
Returns:
out_tensor: output tensor of shape (batch, seqlen, head_dim), allocated in shared HBM
"""
# Hidden should be in BSH layout.
batch, batchless_shape = hidden.shape[0], hidden.shape[1:]
seqlen, dim = batchless_shape
_dim, head_dim = weights.shape

assert dim <= 8192 and dim % 128 == 0, "Unsupported hidden dimension"
assert _dim == dim, "Reduction dimension must match"
assert head_dim <= 512, "Head dimension must be 512 or less"

out_tensor = nl.ndarray((batch, seqlen, head_dim), dtype=hidden.dtype, buffer=nl.shared_hbm)

pmax, fmax = nl.tile_size.pmax, nl.tile_size.psum_fmax # 128, 512
ix, iy = nl.mgrid[0:pmax, 0:dim]
i_lhs = nl.mgrid[0:pmax, 0:pmax]
i_rhs = nl.mgrid[0:pmax, 0:fmax]
i_res = nl.mgrid[0:pmax, 0:fmax]
M = math.ceil(dim / pmax)
NUM_TRANSP_TILES = math.ceil(dim / fmax)
NUM_TILES = math.ceil(seqlen / pmax)
TILES_INT = math.ceil(NUM_TILES / 2)
scale = 1 / dim

iden_x, iden_y = nl.mgrid[0:pmax, 0:128]

identity_a = nl.shared_constant(np.identity(n=128, dtype=np.int8), dtype=hidden.dtype)
identity_tensor = nl.ndarray((par_dim(pmax), 128), dtype=weights.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=0))
identity_tensor[iden_x, iden_y] = nl.load(identity_a, dtype=weights.dtype)
bias_placeholder = nl.ndarray((par_dim(pmax), 1), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=128*2))
bias_placeholder[...] = 0

for b in nl.affine_range(batch):
weights_buffer = nl.ndarray((M, par_dim(pmax), fmax), dtype=weights.dtype,
buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim+fmax)*2+(dim+1)*4, num_free_tiles=(M,)))
# Preload the entire weights tensor. Everything fits in SBUF for LLaMA 3.1 70B.
for m in nl.affine_range(M):
weights_buffer[m, i_rhs.p, i_rhs.x] = nl.load(weights[m*pmax+i_rhs.p, i_rhs.x],
mask=(m*pmax+i_rhs.p<dim) & (i_rhs.x<head_dim))
for i in nl.affine_range(TILES_INT):
# Double buffer the input tensor
in_bufs = nl.ndarray((2, par_dim(pmax), dim), dtype=hidden.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260, num_free_tiles=(2,)))
for i_interleave_grp in nl.affine_range(2):
in_bufs[i_interleave_grp] = nl.load(hidden[b, (2*i+i_interleave_grp)*pmax+ix, iy], mask=(2*i+i_interleave_grp)*pmax+ix < seqlen)
act = nl.ndarray((par_dim(pmax), dim), dtype=norm_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2))

# Write the RMS and RMS Reciprocal tensors back out here, in-place
square_sum = nl.ndarray((par_dim(pmax), 1), dtype=norm_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2+(dim)*4))

# Write the output of RMS and RMS^T (in-place) out to here
out_tile = nl.ndarray((par_dim(pmax), dim), dtype=weights.dtype,
buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2+(dim+1)*4))

# Store the final output tiles to here before sending back to DRAM
output_sbuf = nl.ndarray((par_dim(pmax), fmax), dtype=weights.dtype,
buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim)*2+(dim+1)*4))

act[...] = nisa.activation_reduce(op=nl.square, data=in_bufs[i_interleave_grp], reduce_op=np.add, reduce_res=square_sum[...], bias=bias_placeholder[...])
square_sum[...] = nisa.tensor_scalar(square_sum[...], np.multiply, scale, op1=np.add, operand1=eps)
square_sum[...] = nisa.activation(op=nl.rsqrt, data=square_sum[...], bias=bias_placeholder[...])

# all PE array ops must output to FP32 on trn1 but must match input dtype in trn2
if nisa.get_nc_version() == nisa.nc_version.gen3:
transpose_res_psum = nl.ndarray((NUM_TRANSP_TILES, par_dim(pmax), 4*pmax), dtype=weights.dtype,
buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(1,))) # FIXME: perf is better when all tiles are on bank 0?
else:
transpose_res_psum = nl.ndarray((NUM_TRANSP_TILES, par_dim(pmax), 4*pmax), dtype=np.float32,
buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(1,))) # FIXME: perf is better when all tiles are on bank 0?

for m in nl.affine_range(NUM_TRANSP_TILES):
# Perform (hidden .* RMS Reciprocal)^T in tiles of fmax (512)
out_tile[i_rhs.p, m*fmax+i_rhs.x] = nl.multiply(in_bufs[i_interleave_grp, i_rhs.p, m*fmax + i_rhs.x], square_sum[...], dtype=weights.dtype)
for j in nl.affine_range(4):
transpose_res_psum[m, i_lhs.p, j*pmax+i_lhs.x] = nisa.nc_matmul(out_tile[i_lhs.p, (m*4+j) * pmax + i_lhs.x], identity_tensor[...],
is_transpose=True)
out_tile[i_rhs.p, m * 4*pmax + i_rhs.x] = nl.copy(transpose_res_psum[m], dtype=hidden.dtype)

# perform (RMSNorm(hidden)^T)^T @ wQKV
res_psum = nl.ndarray((1, par_dim(pmax), fmax), dtype=nl.float32,
buffer=ncc.psum.mod_alloc(base_bank=7, num_bank_tiles=(1,)))
for m in nl.affine_range(M):
res_psum[0] += nisa.nc_matmul(out_tile[i_lhs.p, m*pmax+i_lhs.x], weights_buffer[m, i_rhs.p, i_rhs.x])

output_sbuf[...] = nl.copy(res_psum[0], dtype=out_tensor.dtype)
nl.store(out_tensor[b, (2*i+i_interleave_grp)*pmax+i_res.p, i_res.x],
value=output_sbuf,
mask=((2*i+i_interleave_grp)*pmax+i_res.p<seqlen) & (i_res.x<head_dim))
return out_tensor
213 changes: 130 additions & 83 deletions src/reference/attention.py

Large diffs are not rendered by default.

22 changes: 12 additions & 10 deletions src/reference/tutorial.py
@@ -5,9 +5,14 @@
"""

from neuronxcc import nki
import neuronxcc.nki.language as nl

def add_kernel_nx8x128x512(a_ptr, b_ptr, c_ptr, n_elements):

@nki.jit
def add_kernel_nx8x128x512(a_ptr, b_ptr, n_elements):
c_ptr = nl.ndarray(a_ptr.shape, dtype=a_ptr.dtype, buffer=nl.shared_hbm)

ix = nl.arange(128)[:, None]
iy = nl.arange(512)[None, :]

@@ -18,12 +23,9 @@ def add_kernel_nx8x128x512(a_ptr, b_ptr, c_ptr, n_elements):

for i in nl.affine_range(8):
offset = j * block_size + i * tile_size + 512 * ix + iy
mask = offset < n_elements
a_ptr = a_ptr.ptr + offset
b_ptr = b_ptr.ptr + offset
c_ptr = c_ptr.ptr + offset

a = nl.load(a_ptr, mask=mask)
b = nl.load(b_ptr, mask=mask)
c = a + b
nl.store(c_ptr, value=c, mask=mask)
a = nl.load(a_ptr[j, i, ix, iy], mask=offset < n_elements)
b = nl.load(b_ptr[j, i, ix, iy], mask=offset < n_elements)
c = nl.add(a, b, mask=offset < n_elements)
nl.store(c_ptr[j, i, ix, iy], value=c, mask=offset < n_elements)

return c_ptr
21 changes: 17 additions & 4 deletions src/reference/vision.py
@@ -8,10 +8,13 @@

import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
from neuronxcc import nki
from neuronxcc.nki.language import par_dim
import neuronxcc.nki.typing as nt

def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):

@nki.jit
def select_and_scatter_kernel(operand_tensor, source_tensor):
"""
Implementation of a select-and-scatter kernel.
@@ -51,7 +54,10 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):
assert C == 64 and N % 2 == 0

kernel_dtype = operand_tensor.dtype
assert operand_tensor.dtype == source_tensor.dtype == out_tensor.dtype
assert operand_tensor.dtype == source_tensor.dtype

out_tensor = nl.ndarray((N, C, H, W), dtype=operand_tensor.dtype,
buffer=nl.shared_hbm)

p = 128 # num of partitions to use
for ib in nl.affine_range(N // 2):
@@ -156,8 +162,11 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):
nl.store(out_tensor[2 * ib + ib_1, 0:64, 0:H, 0:W],
value=out_local[(ib_1 * 64):((ib_1 + 1) * 64), 0:H, 0:W])

return out_tensor

def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):

@nki.jit
def resize_nearest_fixed_dma_kernel(data_tensor, out_shape):
"""
Resize the input image to the given size using the nearest interpolation mode. This kernel is designed to be used when the scaling factor is not an integer.
@@ -174,7 +183,9 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):
"""
in_b, in_h, in_w, in_c = data_tensor.shape
out_b, out_h, out_w, out_c = out_tensor.shape
out_b, out_h, out_w, out_c = out_shape
out_tensor = nl.ndarray(out_shape, dtype=data_tensor.dtype,
buffer=nl.shared_hbm)

assert in_b == out_b, "Input batch and output batch must be identical"
assert in_c == out_c, "Input channel and output channel must be identical"
@@ -198,3 +209,5 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):
local_data = nl.load(target_addr)
dst_addr_0 = out_tile[b_map, i, c_map]
nl.store(dst_addr_0, value=local_data)

return out_tensor
24 changes: 10 additions & 14 deletions src/tutorials/average_pool2d/average_pool2d_jax.py
@@ -4,42 +4,38 @@
JAX implementation for average pool 2D NKI tutorial.
"""
from functools import partial
from jax_neuronx import nki_call
import jax
# NKI_EXAMPLE_40_BEGIN
import jax.numpy as jnp

from average_pool2d_nki_kernels import tensor_avgpool_kernel_


def tensor_avgpool_kernel(in_array, pool_size):
return nki_call(
partial(tensor_avgpool_kernel_, pool_size=pool_size),
in_array,
out_shape=jax.ShapeDtypeStruct((C, HOUT, WOUT), dtype=in_array.dtype),
)
# NKI_EXAMPLE_40_END
from average_pool2d_nki_kernels import tensor_avgpool_kernel


# NKI_EXAMPLE_40_BEGIN
# Reference JAX implementation
def jax_average_pool_2D(in_tensor, pool_size):
c, h_in, w_in = in_tensor.shape
reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
return jnp.nanmean(reshaped, axis=(2, 4))
# NKI_EXAMPLE_40_END


# NKI_EXAMPLE_41_BEGIN
if __name__ == "__main__":
POOL_SIZE = 2
C, HIN, WIN = 2, 6, 6
HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE

in_array = jnp.arange(C * HIN * WIN, dtype=jnp.float32).reshape(C, HIN, WIN)

# NKI_EXAMPLE_39_BEGIN
out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE)
# NKI_EXAMPLE_39_END
out_jax = jax_average_pool_2D(in_array, pool_size=POOL_SIZE)

print(in_array, out_nki, out_jax)

if jnp.allclose(out_nki, out_jax):
print("NKI and JAX match")
else:
print("NKI and JAX differ")
print("NKI and JAX differ")
# NKI_EXAMPLE_41_END
49 changes: 22 additions & 27 deletions src/tutorials/average_pool2d/average_pool2d_nki_kernels.py
@@ -5,59 +5,56 @@
"""
import numpy as np
# NKI_EXAMPLE_37_BEGIN
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor


def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
@nki.jit
def tensor_avgpool_kernel(in_tensor, pool_size):
"""NKI kernel to compute a 2D avg-pool operation
Args:
in_tensor: an input tensor, of shape C x H x W
pool_size: an integer representing a (square) pool-window size
Return:
out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
"""

# Get input/output dimensions
sz_cin, sz_hin, sz_win = in_tensor.shape
sz_cout, sz_hout, sz_wout = out_tensor.shape
assert sz_cin == sz_cout
sz_hout = sz_hin // pool_size
sz_wout = sz_win // pool_size
# Create output tensor shared between all SPMD instances as result tensor
out_tensor = nl.ndarray((sz_cin, sz_hout, sz_wout), dtype=in_tensor.dtype,
buffer=nl.shared_hbm)

# Set relevant sizes
sz_p = sz_cin
sz_pool = pool_size

# Generate tensor h/w index patterns
# 3D indexing according to [C, H, W]
i_p = nl.arange(sz_p)[:, None, None] # 3D for
i_win = nl.arange(sz_win)[None, None, :]
i_hin = nl.arange(sz_hin)[None, :, None]

i_wout = nl.arange(sz_wout)[None, None, :]
i_hout = nl.arange(sz_hout)[None, :, None]

# Generate pool index patterns (requires two extra dimensions, for the pool window)
i_0 = nl.arange(sz_p)[:, None, None, None, None] #
i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer
i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner
i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer
i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner
i0, i1, i2, i3, i4 = nl.mgrid[0:sz_p, 0:sz_hin//sz_pool, 0:sz_pool, 0:sz_win//sz_pool, 0:sz_pool]

# Load input data from external memory to on-chip memory
# Declare ndarray to force a 3D tensor (temporary requirement)
in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])
in_tile: tensor[sz_p, sz_hin, sz_win] = nl.load(in_tensor)

# Perform the pooling operation:
# We use numpy's advanced indexing, in order to extend in_tile to 5D, and then reduce-average two dimensions.
# axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
# axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
# (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
# Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size)
out_tile : tensor[sz_p, sz_hout, sz_wout] = nl.sum(in_tile[i0, sz_pool*i1+i2, sz_pool*i3+i4],
axis=[2,4]) / (pool_size*pool_size)

# Store the results back to hbm
nl.store(out_tensor, value=out_tile)

# Store the results back to external memory
nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)
# Transfer the ownership of `out_tensor` to the caller
return out_tensor
# NKI_EXAMPLE_37_END


# Reference NumPy implementation
@@ -74,10 +71,8 @@ def np_average_pool_2D(in_tensor, pool_size):
HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE

in_tensor = np.arange(C * HIN * WIN, dtype=np.float16).reshape(C, HIN, WIN)
out_nki = np.zeros((C, HOUT, WOUT), dtype=np.float16)

tensor_avgpool_kernel_baremetal = nki.baremetal(tensor_avgpool_kernel_)
tensor_avgpool_kernel_baremetal(in_tensor, out_nki, POOL_SIZE)
out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE)

out_np = np_average_pool_2D(in_tensor, POOL_SIZE)

11 changes: 6 additions & 5 deletions src/tutorials/average_pool2d/average_pool2d_torch.py
Original file line number Diff line number Diff line change
@@ -4,13 +4,14 @@
PyTorch implementation for average pool 2D NKI tutorial.
"""
# NKI_EXAMPLE_38_BEGIN
import torch
from torch_neuronx import nki_jit
from torch_xla.core import xla_model as xm

from average_pool2d_nki_kernels import tensor_avgpool_kernel_
# NKI_EXAMPLE_38_END
from average_pool2d_nki_kernels import tensor_avgpool_kernel


# NKI_EXAMPLE_38_BEGIN
if __name__ == "__main__":
device = xm.xla_device()

@@ -22,8 +23,7 @@
in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device)
out_nki = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device)

tensor_avgpool_kernel_torch = nki_jit(tensor_avgpool_kernel_)
tensor_avgpool_kernel_torch(in_tensor, out_nki, POOL_SIZE)
out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE)

out_torch = torch.nn.functional.avg_pool2d(in_tensor, POOL_SIZE, POOL_SIZE)

@@ -33,3 +33,4 @@
print("NKI and Torch match")
else:
print("NKI and Torch differ")
# NKI_EXAMPLE_38_END
38 changes: 26 additions & 12 deletions src/tutorials/fused_mamba/mamba_nki_kernels.py
@@ -4,16 +4,19 @@
Mamba-v1 NKI kernel implementation.
"""
# NKI_EXAMPLE_25_BEGIN
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import numpy as np
# NKI_EXAMPLE_25_END
import os
import argparse
import itertools


def mamba_v1(delta, u, A, B, C, output):
# NKI_EXAMPLE_25_BEGIN
@nki.jit
def mamba_v1(delta, u, A, B, C):
"""Computes the SSM operation in the Mamba model.
:param delta: (batch_size, channels, seq_len)
@@ -24,6 +27,9 @@ def mamba_v1(delta, u, A, B, C, output):
:return: (batch_size, channels, seq_len)
"""
batch_size, channels, seq_len = delta.shape
output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype,
buffer=nl.shared_hbm)

_, state_size = A.shape

# We can relax this using mask parameters in all the NKI API calls
@@ -84,8 +90,12 @@ def mamba_v1(delta, u, A, B, C, output):
nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len])

return output
# NKI_EXAMPLE_25_END

def mamba_v2(delta, u, A, B, C, output):
# NKI_EXAMPLE_26_BEGIN
@nki.jit
def mamba_v2(delta, u, A, B, C):
"""Computes the SSM operation in the Mamba model.
:param delta: (batch_size, channels, seq_len)
@@ -96,6 +106,8 @@ def mamba_v2(delta, u, A, B, C, output):
:return: (batch_size, channels, seq_len)
"""
batch_size, channels, seq_len = delta.shape
output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype,
buffer=nl.shared_hbm)
_, state_size = A.shape

assert channels % 128 == 0
@@ -153,8 +165,12 @@ def mamba_v2(delta, u, A, B, C, output):
nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
scanC_accum[0:channel_psize, 0:seq_len])

return output
# NKI_EXAMPLE_26_END


def mamba_v3(delta, u, A, B, C, output):
@nki.jit
def mamba_v3(delta, u, A, B, C):
"""Computes the SSM operation in the Mamba model.
:param delta: (batch_size, channels, seq_len)
@@ -165,6 +181,8 @@ def mamba_v3(delta, u, A, B, C, output):
:return: (batch_size, channels, seq_len)
"""
batch_size, channels, seq_len = delta.shape
output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype,
buffer=nl.shared_hbm)
_, state_size = A.shape

# Map channels to the partition dimension
@@ -239,6 +257,7 @@ def mamba_v3(delta, u, A, B, C, output):
# Store scanC_accum for a single batch to output
nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
scanC_accum[0:channel_psize, 0:seq_len])
return output


def parse_args():
@@ -310,9 +329,7 @@ def parse_args():
if args.mode == "accuracy":
# v1: reference kernel
print(f">>>> Running v1 (reference).")
nki_out_v1 = np.empty((batch, channels, seq_len), dtype=dtype)
nki.baremetal(mamba_v1)\
(delta, u, A, B, C, nki_out_v1)
nki_out_v1 = mamba_v1(delta, u, A, B, C)

for version in args.version:
if version == "v1":
@@ -321,9 +338,7 @@ def parse_args():

print(f">>>> Running version {version}.")
func = func_dict[version]
nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype)
nki.baremetal(func)\
(delta, u, A, B, C, nki_out_test)
nki_out_test = func(delta, u, A, B, C)
print(f">>>> mamba {version} matches?", np.all(nki_out_test == nki_out_v1))
assert np.all(nki_out_test == nki_out_v1)

@@ -333,11 +348,10 @@ def parse_args():
for version in args.version:
print(f">>>> Running version {version}.")
func = func_dict[version]
nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype)
nki.benchmark(func,
save_neff_name='file.neff',
save_trace_name='profile.ntff')\
(delta, u, A, B, C, nki_out_test)
(delta, u, A, B, C)
# TODO: rename neff/ntff (bug in nki.benchmark with neff name)
os.rename("file.neff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.neff")
os.rename("profile.ntff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.ntff")
7 changes: 3 additions & 4 deletions src/tutorials/fused_mamba/mamba_torch.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
"""

# NKI_EXAMPLE_24_BEGIN
import torch
import torch_neuronx
import torch_xla.core.xla_model as xm
@@ -99,16 +100,14 @@ def parse_args():
torch_out = mamba_layer(delta, A, B, u, C)
xm.mark_step()
print(torch_out)
# NKI_EXAMPLE_24_END

if args.mode == "accuracy":
# Call NKI mamba_v1 kernel to check accuracy
from mamba_nki_kernels import mamba_v1
from torch_neuronx import nki_jit

nki_out = torch.empty((batch, channels, seq_len), dtype=dtype, device=device)

xm.mark_step()
nki_jit(mamba_v1)(delta, u, A, B, C, nki_out)
nki_out = mamba_v1(delta, u, A, B, C)
xm.mark_step()

allclose = torch.allclose(torch_out, nki_out, atol=1e-2, rtol=1e-2)
143 changes: 77 additions & 66 deletions src/tutorials/layernorm/layernorm_nki_kernel.py
@@ -4,21 +4,27 @@
LayerNorm NKI kernel implementation.
"""
# NKI_EXAMPLE_45_BEGIN
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import numpy as np
import math
# NKI_EXAMPLE_45_END
import os
import argparse


def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor):
# NKI_EXAMPLE_45_BEGIN
@nki.jit
def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector):
"""Computes LayerNorm.
Used nki.language APIs only.
"""
output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype,
buffer=nl.shared_hbm)

# Ensure that the shapes of tensors match
assert input_tensor.shape == output_tensor.shape
assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0]

# Generate tile indices for loading/storing data
@@ -58,12 +64,20 @@ def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, ou
nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb,
mask=(i * nl.tile_size.pmax + i_p_io < num_rows))

def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor):
return output_tensor
# NKI_EXAMPLE_45_END


# NKI_EXAMPLE_46_BEGIN
@nki.jit
def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector):
"""Computes LayerNorm.
Used nki.isa APIs to calculate mean/variance and perform shift/scale.
"""
output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype,
buffer=nl.shared_hbm)

# Ensure that the shapes of tensors match
assert input_tensor.shape == output_tensor.shape
assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0]

# Generate tile indices for loading/storing data
@@ -122,69 +136,66 @@ def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, ou
nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb,
mask=(i * nl.tile_size.pmax + i_p_io < num_rows))

return output_tensor
# NKI_EXAMPLE_46_END


def parse_args():
parser = argparse.ArgumentParser(
"""Run LayerNorm pytorch implementation.
""")
parser.add_argument("--nrows",
default=4*1024,
type=int,
help="""The number of input rows""")
parser.add_argument("--ncols",
default=8*1024,
type=int,
help="""The number of input columns""")
parser.add_argument("--mode",
choices=["accuracy", "perf"],
default="accuracy",
help="""Do accuracy test or perf test.
Accuracy test compares LayerNorm kernel against PyTorch implementation.
Perf test will generate a NEFF for the PyTorch implementation in local directory
for a manual run of neuron-profile.
""")
args = parser.parse_args()
return args
parser = argparse.ArgumentParser(
"""Run LayerNorm pytorch implementation.
""")
parser.add_argument("--nrows",
default=4*1024,
type=int,
help="""The number of input rows""")
parser.add_argument("--ncols",
default=8*1024,
type=int,
help="""The number of input columns""")
parser.add_argument("--mode",
choices=["accuracy", "perf"],
default="accuracy",
help="""Do accuracy test or perf test.
Accuracy test compares LayerNorm kernel against PyTorch implementation.
Perf test will generate a NEFF for the PyTorch implementation in local directory
for a manual run of neuron-profile.
""")
args = parser.parse_args()
return args

if __name__ == "__main__":
args = parse_args()
func_dict = {"v1": nki_layernorm_kernel_v1,
"v2": nki_layernorm_kernel_v2,
}

# Generate toy example
num_rows = args.nrows
num_cols = args.ncols
input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32)
gamma_vector = np.random.rand(num_cols).astype(np.float32)
beta_vector = np.random.rand(num_cols).astype(np.float32)
epsilon = 1e-5

if args.mode == "accuracy":
# version 1
print(f">>>> Running version 1")
nki_out_v1 = np.empty((num_rows, num_cols), dtype=np.float32)
nki.baremetal(nki_layernorm_kernel_v1)\
(input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v1)
# version 2
print(f">>>> Running version 2")
nki_out_v2 = np.empty((num_rows, num_cols), dtype=np.float32)
nki.baremetal(nki_layernorm_kernel_v2)\
(input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v2)
# compare
np_all = np.all(nki_out_v1 == nki_out_v2)
print(f">>>> LayerNorm V1 and V2 match?", np_all)
assert np_all

else:
# perf mode
for version in ["v1", "v2"]:
print(f">>>> Running version {version}.")
func = func_dict[version]
nki_out_test = np.empty((num_rows, num_cols), dtype=np.float32)
nki.benchmark(func,
save_neff_name='file.neff',
save_trace_name='profile.ntff')\
(input_tensor, epsilon, gamma_vector, beta_vector, nki_out_test)
os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff")
os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff")
args = parse_args()
func_dict = {"v1": nki_layernorm_kernel_v1,
"v2": nki_layernorm_kernel_v2,
}

# Generate toy example
num_rows = args.nrows
num_cols = args.ncols
input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32)
gamma_vector = np.random.rand(num_cols).astype(np.float32)
beta_vector = np.random.rand(num_cols).astype(np.float32)
epsilon = 1e-5

if args.mode == "accuracy":
# version 1
print(f">>>> Running version 1")
nki_out_v1 = nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector)
# version 2
print(f">>>> Running version 2")
nki_out_v2 = nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector)
# compare
np_all = np.all(nki_out_v1 == nki_out_v2)
print(f">>>> LayerNorm V1 and V2 match?", np_all)
assert np_all

else:
# perf mode
for version in ["v1", "v2"]:
print(f">>>> Running version {version}.")
func = func_dict[version]
benchmark_kernel = nki.benchmark(func, save_neff_name='file.neff',
save_trace_name='profile.ntff')
nki_out_test = benchmark_kernel(input_tensor, epsilon, gamma_vector, beta_vector)
os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff")
os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff")
18 changes: 10 additions & 8 deletions src/tutorials/layernorm/layernorm_torch.py
@@ -4,9 +4,9 @@
LayerNorm NKI kernel implementation.
"""
# NKI_EXAMPLE_47_BEGIN
import torch
from torch_xla.core import xla_model as xm
from torch_neuronx import nki_jit
import argparse
import os

@@ -42,13 +42,16 @@ def parse_args():
args = parser.parse_args()
return args


from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, \
nki_layernorm_kernel_v2

if __name__ == "__main__":
args = parse_args()
from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, nki_layernorm_kernel_v2
func_dict = {"v1": nki_layernorm_kernel_v1,
"v2": nki_layernorm_kernel_v2,
}

device = xm.xla_device()
num_rows = args.nrows
num_cols = args.ncols
@@ -58,25 +61,23 @@ def parse_args():
gamma_vector = torch.rand((num_cols), dtype=torch.float32)
beta_vector = torch.rand((num_cols), dtype=torch.float32)
epsilon = 1e-5

# Compute torch layernorm layer in cpu
output_torch = layernorm_layer(input_tensor, epsilon, gamma_vector, beta_vector)

# Copy tensors to NeuronDevice
input_tensor = input_tensor.to(device=device)
gamma_vector = gamma_vector.to(device=device)
beta_vector = beta_vector.to(device=device)
output_nki = torch.zeros((num_rows, num_cols), dtype=torch.float32).to(device=device)

print(f">>>> Running version {args.version}.")
func = func_dict[args.version]

# add nki_jit decorator
nki_layernorm_kernel = nki_jit(func)

# Compute NKI layernorm kernel in NeuronDevice
xm.mark_step()
nki_layernorm_kernel(input_tensor, epsilon, gamma_vector, beta_vector, output_nki)
output_nki = func(input_tensor, epsilon, gamma_vector, beta_vector)
xm.mark_step()
output_nki = output_nki.to(device='cpu')

@@ -86,5 +87,6 @@ def parse_args():
print("NKI and Torch match")
else:
print("NKI and Torch differ")

# NKI_EXAMPLE_47_END

assert allclose
@@ -12,16 +12,21 @@
import numpy as np


def nki_matmul_basic_(lhsT, rhs, result):
# NKI_EXAMPLE_16_BEGIN
@nki.jit
def nki_matmul_basic_(lhsT, rhs):
"""NKI kernel to compute a 64x128x512 matrix multiplication operation
Args:
lhsT: an input tensor of shape [128,64], a left hand side argument of the
matrix multiplication, delivered transposed for optimal performance
rhs: an input tensor of shape [128,512], a right hand side argument of the
matrix multiplication
Returns:
result: the resulting output tensor of shape [64,512]
"""
result = nl.ndarray((64, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm)

# Defining indexes for input LHS.T
# - Note: here we take LayoutConstraint #1 into account:
# "For MatMult, contraction axis must be mapped to P-dim"
@@ -53,8 +58,13 @@ def nki_matmul_basic_(lhsT, rhs, result):
# This dictates which indices to use to address the result tile.
nl.store(result[i_out_p, i_out_f], value=result_sbuf)

return result
# NKI_EXAMPLE_16_END


def nki_matmul_tiled_(lhsT, rhs, result):
# NKI_EXAMPLE_18_BEGIN
@nki.jit
def nki_matmul_tiled_(lhsT, rhs):
"""NKI kernel to compute a matrix multiplication operation in a tiled manner
Args:
@@ -64,12 +74,14 @@ def nki_matmul_tiled_(lhsT, rhs, result):
rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
is a multiple of 512. It is the right-hand-side argument of the matrix
multiplication.
Returns:
result: the resulting output tensor of shape [M,N]
"""

K, M = lhsT.shape
K_, N = rhs.shape
assert K == K_, "lhsT and rhs must have the same contraction dimension"
result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

TILE_M = nl.tile_size.gemm_stationary_fmax # 128
TILE_K = nl.tile_size.pmax # 128
@@ -100,8 +112,13 @@ def nki_matmul_tiled_(lhsT, rhs, result):
nl.store(result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N],
value=res_sb)

return result
# NKI_EXAMPLE_18_END

def nki_matmul_hoist_load_(lhsT, rhs, result):

# NKI_EXAMPLE_19_BEGIN
@nki.jit
def nki_matmul_hoist_load_(lhsT, rhs):
"""NKI kernel to compute a matrix multiplication operation in a tiled manner
while hoisting the load of the lhsT and rhs to outer loops.
@@ -112,12 +129,14 @@ def nki_matmul_hoist_load_(lhsT, rhs, result):
rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
is a multiple of 512. It is the right-hand-side argument of the matrix
multiplication.
Returns:
result: the resulting output tensor of shape [M,N]
"""

K, M = lhsT.shape
K_, N = rhs.shape
assert K == K_, "lhsT and rhs must have the same contraction dimension"
result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

TILE_M = nl.tile_size.gemm_stationary_fmax # 128
TILE_K = nl.tile_size.pmax # 128
@@ -163,8 +182,13 @@ def nki_matmul_hoist_load_(lhsT, rhs, result):
res_sb = nl.copy(res_psum, dtype=result.dtype)
nl.store(result[m * TILE_M + i_res.p, n * TILE_N + i_res.x], value=res_sb)

return result
# NKI_EXAMPLE_19_END


def nki_matmul_block_free_dimension_(lhsT, rhs, result):
# NKI_EXAMPLE_20_BEGIN
@nki.jit
def nki_matmul_block_free_dimension_(lhsT, rhs):
"""NKI kernel to compute a matrix multiplication operation while blocking the
free dimensions of the LHS and RHS to improve memory access pattern.
@@ -175,12 +199,14 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result):
rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
is a multiple of 512. It is the right-hand-side argument of the matrix
multiplication.
Returns:
result: the resulting output tensor of shape [M,N]
"""

K, M = lhsT.shape
K_, N = rhs.shape
assert K == K_, "lhsT and rhs must have the same contraction dimension"
result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

TILE_M = nl.tile_size.gemm_stationary_fmax # 128
TILE_K = nl.tile_size.pmax # 128
@@ -243,11 +269,15 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result):
(n * TILES_IN_BLOCK_N + bn) * TILE_N + i_res.x],
value=res_sb)

return result
# NKI_EXAMPLE_20_END


# NKI_EXAMPLE_21_BEGIN
@nki.jit
def nki_matmul_fully_optimized_(
lhsT,
rhs,
result,
# Meta-parameters
TILES_IN_BLOCK_M=16,
TILES_IN_BLOCK_N=2,
@@ -264,13 +294,15 @@ def nki_matmul_fully_optimized_(
rhs: an input tensor of shape [K,N], where K is a multiple of 128 *
TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N. It is
the right-hand-side argument of the matrix multiplication.
result: the resulting output tensor of shape [M,N]
TILES_IN_BLOCK_*: meta parameters to control blocking dimensions
Returns:
result: the resulting output tensor of shape [M,N]
"""

K, M = lhsT.shape
K_, N = rhs.shape
assert K == K_, "lhsT and rhs must have the same contraction dimension"
result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

TILE_M = nl.tile_size.gemm_stationary_fmax # 128
TILE_K = nl.tile_size.pmax # 128
@@ -360,16 +392,19 @@ def nki_matmul_fully_optimized_(
BLOCK_N * n + i_res_packed.x],
value=result_packed[i_res_packed.p, i_res_packed.x])

return result
# NKI_EXAMPLE_21_END


# NKI_EXAMPLE_23_BEGIN
if __name__ == "__main__":
# Benchmarking with large matrices to show the differences more clearly
lhsT = nt.tensor[[8192, 4096], nl.bfloat16]
rhs = nt.tensor[[8192, 8192], nl.bfloat16]
output = nt.tensor[[4096, 8192], nl.bfloat16]

def benchmark_nki(nki_func):
bench_func = nki.benchmark(warmup=5, iters=10)(nki_func)
bench_func(lhsT, rhs, output)
bench_func(lhsT, rhs)
latency_res = bench_func.benchmark_result.nc_latency
p99 = latency_res.get_latency_percentile(99)
print("Latency: {:.2f} ms (P99)".format(p99 / 1000.0))
@@ -385,3 +420,4 @@ def benchmark_nki(nki_func):

print("Benchmarking nki_matmul_fully_optimized")
benchmark_nki(nki_matmul_fully_optimized_)
# NKI_EXAMPLE_23_END
@@ -7,23 +7,21 @@

import torch
from torch_xla.core import xla_model as xm
from torch_neuronx import nki_jit

from matrix_multiplication_nki_kernels import nki_matmul_basic_, nki_matmul_tiled_, nki_matmul_hoist_load_, nki_matmul_block_free_dimension_, nki_matmul_fully_optimized_

if __name__ == "__main__":

# NKI_EXAMPLE_17_BEGIN
device = xm.xla_device()
cpu = torch.device('cpu')

# Test the small workload with basic kernel
lhs_small = torch.rand((64, 128), dtype=torch.bfloat16, device=device)
rhs_small = torch.rand((128, 512), dtype=torch.bfloat16, device=device)
output_small = torch.zeros((64, 512), dtype=torch.bfloat16, device=device)

# Run NKI kernel
nki_matmul_basic_jit = nki_jit(nki_matmul_basic_)
nki_matmul_basic_jit(lhs_small.T, rhs_small, output_small)
output_small = nki_matmul_basic_(lhs_small.T, rhs_small)

# Run torch reference
output_small_torch = torch.matmul(lhs_small, rhs_small)
@@ -34,18 +32,18 @@
print("NKI and Torch match")
else:
print("NKI and Torch differ")
# NKI_EXAMPLE_17_END

# NKI_EXAMPLE_22_BEGIN
# Test the large workload with tiled kernels
lhs = torch.rand((4096, 1024), dtype=torch.bfloat16, device=device)
rhs = torch.rand((1024, 2048), dtype=torch.bfloat16, device=device)
output = torch.zeros((4096, 2048), dtype=torch.bfloat16, device=device)

# Run torch reference
output_torch = torch.matmul(lhs, rhs).to(device=cpu)

def check_match(nki_func):
jit_func = nki_jit(nki_func)
jit_func(lhs.T, rhs, output)
output = nki_func(lhs.T, rhs)
output_nki = output.to(device=cpu)
if torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2):
print("NKI and Torch match")
@@ -63,3 +61,4 @@ def check_match(nki_func):

print("Checking correctness of nki_matmul_fully_optimized")
check_match(nki_matmul_fully_optimized_)
# NKI_EXAMPLE_22_END
9 changes: 3 additions & 6 deletions src/tutorials/rmsnorm/rmsnorm_jax.py
@@ -7,9 +7,9 @@

import jax
import jax.numpy as jnp
from jax_neuronx import nki_call
from rmsnorm_nki_kernels import nki_rmsnorm_kernel

# NKI_EXAMPLE_44_BEGIN
# Reference JAX implementation
def jax_rms_norm(a_tensor, g_tensor):
# Square the tensor (element-wise)
@@ -26,11 +26,7 @@ def jax_rms_norm(a_tensor, g_tensor):
a_tensor = jax.random.uniform(a_key, (250, 512))
g_tensor = jax.random.uniform(g_key, (512,))

output_nki = nki_call(
nki_rmsnorm_kernel,
a_tensor, g_tensor,
out_shape=jax.ShapeDtypeStruct(a_tensor.shape, dtype=a_tensor.dtype),
)
output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor)

print(a_tensor)

@@ -43,3 +39,4 @@ def jax_rms_norm(a_tensor, g_tensor):
print("NKI and JAX match")
else:
print("NKI and JAX differ")
# NKI_EXAMPLE_44_END
14 changes: 9 additions & 5 deletions src/tutorials/rmsnorm/rmsnorm_nki_kernels.py
@@ -6,20 +6,23 @@
"""

import numpy as np
# NKI_EXAMPLE_42_BEGIN
import math
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor):
@nki.jit
def nki_rmsnorm_kernel(a_tensor, g_tensor):
# Calculate out_tensor = a_tensor/RMS(a_tensor) * g_tensor
# Where RMS(a_tensor) = sqrt((1/N) * sum(a_tensor * a_tensor))
# and N = a_tensor.shape[1]
# Reduction (mean) is performed in the free (2nd) dimension
out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype,
buffer=nl.shared_hbm)

# Make sure shapes match
assert a_tensor.shape[1] == g_tensor.shape[0]
assert a_tensor.shape == out_tensor.shape

# Generate tensor indices to index input tensor
ix = nl.arange(128)[:, None]
@@ -68,14 +71,15 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor):
nl.store(out_tensor[i * 128 + ix, iy], value=out_tile,
mask=(i * 128 + ix < num_rows))

return out_tensor
# NKI_EXAMPLE_42_END


if __name__ == "__main__":
a = np.random.rand(128, 512).astype(np.float32)
g = np.random.rand(512).astype(np.float32)
output_nki = np.zeros(a.shape, dtype=a.dtype)

nki_rmsnorm_kernel_baremetal = nki.baremetal(nki_rmsnorm_kernel)
nki_rmsnorm_kernel_baremetal(a, g, output_nki)
output_nki = nki_rmsnorm_kernel(a, g)
print(f"output_nki={output_nki}")

# One-line numpy RMSNorm
8 changes: 3 additions & 5 deletions src/tutorials/rmsnorm/rmsnorm_torch.py
@@ -5,11 +5,11 @@
"""

from torch_neuronx.xla_impl.ops import nki_jit
import torch
import os
from rmsnorm_nki_kernels import nki_rmsnorm_kernel

# NKI_EXAMPLE_43_BEGIN
# Reference torch implementation
def torch_rmsnorm_kernel(a_tensor, g_tensor):
# Square the tensor (element-wise)
@@ -25,13 +25,10 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor):
from torch_xla.core import xla_model as xm
device = xm.xla_device()

nki_rmsnorm_kernel = nki_jit(nki_rmsnorm_kernel)

a_tensor = torch.rand((250, 512), dtype=torch.float32).to(device=device)
g_tensor = torch.rand((512), dtype=torch.float32).to(device=device)
output_nki = torch.zeros((250, 512), dtype=torch.float32).to(device=device)

nki_rmsnorm_kernel(a_tensor, g_tensor, output_nki)
output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor)
print(f"output_nki={output_nki}")

output_torch = torch_rmsnorm_kernel(a_tensor, g_tensor)
@@ -41,3 +38,4 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor):
print("NKI and Torch match")
else:
print("NKI and Torch differ")
# NKI_EXAMPLE_43_END
101 changes: 53 additions & 48 deletions src/tutorials/sd_attention/sd_attention_nki_kernels.py
@@ -12,7 +12,9 @@
import argparse
from scipy.special import softmax

def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False,
# NKI_EXAMPLE_31_BEGIN
@nki.jit
def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
mixed_percision=True):
"""
Fused self attention kernel for small head dimension Stable Diffusion workload,
@@ -44,7 +46,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
# Assume all IO tensors have the same dtype
kernel_dtype = q_ref.dtype
pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype
assert q_ref.dtype == k_ref.dtype == v_ref.dtype

# Shape checking
seqlen, d_head = q_ref.shape
@@ -53,7 +55,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
assert tuple(k_ref.shape) == (seqlen, d_head), 'Input shape mismatch!'
assert tuple(v_ref.shape) == (seqlen,d_head), \
f'Input shape mismatch! Expected: {(seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
assert tuple(out_ref.shape) == (seqlen, d_head), 'Output shape mismatch!'
out_ref = nl.ndarray((seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm)

# Softmax scaling factor, multiplied onto Q
softmax_scale = 0.125
@@ -210,58 +212,61 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
out_ref[i_q_seq_tile * q_seq_tile_size + if_out, ip_out],
value=attn_res_div)

return out_ref
# NKI_EXAMPLE_31_END


def parse_args():
parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.")
parser.add_argument("--mode",
choices=["accuracy", "perf"],
default="accuracy",
help="""Do accuracy test or perf test.
Accuracy test uses cpu golden output as golden reference.
""")
parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.")
parser.add_argument("--mode",
choices=["accuracy", "perf"],
default="accuracy",
help="""Do accuracy test or perf test.
Accuracy test uses cpu golden output as golden reference.
""")

args = parser.parse_args()
return args

args = parser.parse_args()
return args

def cpu_golden_attn(q, k, v):
softmax_scale = 0.125
softmax_scale = 0.125

q_scaled = q * softmax_scale
raw_score = np.matmul(q_scaled, k.transpose())
norm_score = softmax(raw_score, axis=-1)
q_scaled = q * softmax_scale
raw_score = np.matmul(q_scaled, k.transpose())
norm_score = softmax(raw_score, axis=-1)

return np.matmul(norm_score, v)
return np.matmul(norm_score, v)


if __name__ == "__main__":
args = parse_args()

print(f"Running {args.mode} mode.")

seqlen, d_head = 4096, 64

# Set up input tensors
dtype = np.float32
q_tensor = np.random.rand(seqlen, d_head).astype(dtype)
k_tensor = np.random.rand(seqlen, d_head).astype(dtype)
v_tensor = np.random.rand(seqlen, d_head).astype(dtype)
output_nki = np.empty((seqlen, d_head), dtype=dtype)
output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor)

if args.mode == "accuracy":
nki.baremetal(fused_self_attn_for_SD_small_head_size)\
(q_tensor, k_tensor, v_tensor, output_nki)
allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3)
print(f">>>> SD attention matches CPU reference? {allclose}")
assert allclose, "Accuracy check fails!"

else:
benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size,
save_neff_name='file.neff',
save_trace_name='profile.ntff')
benchmark_func(q_tensor, k_tensor, v_tensor, output_nki)

metrics = benchmark_func.benchmark_result.nc_latency
print(">>>> SD attention benchmark results")
print("latency.p50 = " + str(metrics.get_latency_percentile(50)))
print("latency.p99 = " + str(metrics.get_latency_percentile(99)))
args = parse_args()

print(f"Running {args.mode} mode.")

seqlen, d_head = 4096, 64

# Set up input tensors
dtype = np.float32
q_tensor = np.random.rand(seqlen, d_head).astype(dtype)
k_tensor = np.random.rand(seqlen, d_head).astype(dtype)
v_tensor = np.random.rand(seqlen, d_head).astype(dtype)
output_nki = np.empty((seqlen, d_head), dtype=dtype)
output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor)

if args.mode == "accuracy":
output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor)
allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3)
print(f">>>> SD attention matches CPU reference? {allclose}")
assert allclose, "Accuracy check fails!"

else:
benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size,
save_neff_name='file.neff',
save_trace_name='profile.ntff')
benchmark_func(q_tensor, k_tensor, v_tensor)

metrics = benchmark_func.benchmark_result.nc_latency
print(">>>> SD attention benchmark results")
print("latency.p50 = " + str(metrics.get_latency_percentile(50)))
print("latency.p99 = " + str(metrics.get_latency_percentile(99)))
9 changes: 4 additions & 5 deletions src/tutorials/sd_attention/sd_attention_torch.py
@@ -5,8 +5,8 @@
"""

# NKI_EXAMPLE_32_BEGIN
import torch
from torch_neuronx.xla_impl.ops import nki_jit
from torch_xla.core import xla_model as xm

from sd_attention_nki_kernels import fused_self_attn_for_SD_small_head_size
@@ -28,10 +28,8 @@ def cpu_golden_attn(q, k, v):
q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
k_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
v_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
output_nki = torch.zeros((4096, 64), dtype=torch.float32).to(device=device)

nki_func = nki_jit(func=fused_self_attn_for_SD_small_head_size)
nki_func(q_tensor, k_tensor, v_tensor, output_nki)
output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor)

output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor)

@@ -42,4 +40,5 @@ def cpu_golden_attn(q, k, v):
else:
print("NKI and Torch differ")

assert allclose
assert allclose
# NKI_EXAMPLE_32_END
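
The golden reference used by both the baremetal test and this PyTorch script is plain scaled dot-product attention with softmax_scale = 0.125 and no causal mask, as defined in cpu_golden_attn above. An equivalent PyTorch sketch for 2D (seqlen, d_head) inputs, shown only to make the comparison target explicit:

import torch

def torch_golden_attn(q, k, v, softmax_scale=0.125):
    # softmax((q * scale) @ k^T) @ v
    raw_score = torch.matmul(q * softmax_scale, k.transpose(0, 1))
    norm_score = torch.softmax(raw_score, dim=-1)
    return torch.matmul(norm_score, v)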
36 changes: 5 additions & 31 deletions src/tutorials/tensor_addition/tensor_addition_jax.py
@@ -4,42 +4,15 @@
JAX implementation for tensor addition NKI tutorial.
"""
# NKI_EXAMPLE_30_BEGIN
import jax
import jax.numpy as jnp
from jax_neuronx import nki_call
# NKI_EXAMPLE_30_END

from tensor_addition_nki_kernels import nki_tensor_add_kernel_


def nki_tensor_add(a_input, b_input):
"""NKI kernel caller to compute element-wise addition of two input tensors
This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs
Args:
a_input: a first input tensor, of shape [N*128, M*512]
b_input: a second input tensor, of shape [N*128, M*512]
Returns:
a tensor of shape [N*128, M*512], the result of a_input + b_input
"""

# The SPMD launch grid denotes the number of kernel instances.
# In this case, we use a 2D grid where the size of each invocation is 128x512
grid_x = a_input.shape[0] // 128
grid_y = a_input.shape[1] // 512

out_shape = jax.ShapeDtypeStruct((a_input.shape[0], a_input.shape[1]), dtype=a_input.dtype)

return nki_call(
nki_tensor_add_kernel_,
a_input,
b_input,
grid=(grid_x, grid_y),
out_shape=out_shape,
)
from tensor_addition_nki_kernels import nki_tensor_add


# NKI_EXAMPLE_30_BEGIN
if __name__ == "__main__":

seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
@@ -59,3 +32,4 @@ def nki_tensor_add(a_input, b_input):
print("NKI and JAX differ")

assert allclose
# NKI_EXAMPLE_30_END
28 changes: 18 additions & 10 deletions src/tutorials/tensor_addition/tensor_addition_nki_kernels.py
@@ -5,20 +5,26 @@
"""
import numpy as np
# NKI_EXAMPLE_27_BEGIN
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


def nki_tensor_add_kernel_(a_input, b_input, c_output):
@nki.jit
def nki_tensor_add_kernel_(a_input, b_input):
"""NKI kernel to compute element-wise addition of two input tensors
This kernel assumes strict input/output tile-sizes, of up-to [128,512]
This kernel assumes strict input/output sizes can be uniformly tiled to [128,512]
Args:
a_input: a first input tensor, of shape [128,512]
b_input: a second input tensor, of shape [128,512]
c_output: an output tensor, of shape [128,512]
a_input: a first input tensor
b_input: a second input tensor
Returns:
c_output: an output tensor
"""
# Create output tensor shared between all SPMD instances as result tensor
c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

# Calculate tile offsets based on current 'program'
offset_i_x = nl.program_id(0) * 128
@@ -39,7 +45,12 @@ def nki_tensor_add_kernel_(a_input, b_input, c_output):
# store the addition results back to device memory (c_output)
nl.store(c_output[ix, iy], value=c_tile)

# Transfer the ownership of `c_output` to the caller
return c_output
# NKI_EXAMPLE_27_END


# NKI_EXAMPLE_28_BEGIN
def nki_tensor_add(a_input, b_input):
"""NKI kernel caller to compute element-wise addition of two input tensors
@@ -57,12 +68,9 @@ def nki_tensor_add(a_input, b_input):
# In this case, we use a 2D grid where the size of each invocation is 128x512
grid_x = a_input.shape[0] // 128
grid_y = a_input.shape[1] // 512
c_output = np.zeros(a_input.shape, dtype=a_input.dtype)

nki_tensor_add_kernel_baremetal = nki.baremetal(nki_tensor_add_kernel_)
nki_tensor_add_kernel_baremetal[grid_x, grid_y](a_input, b_input, c_output)

return c_output
return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
# NKI_EXAMPLE_28_END


if __name__ == "__main__":
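
The bracket indexing used here is the SPMD launch grid: each kernel instance processes one 128x512 tile, so the grid is simply the input shape divided by the tile shape. A small usage sketch, assuming NumPy inputs whose dimensions are exact multiples of 128 and 512 (as the caller requires) and the direct-call behavior shown in the __main__ blocks above:

import numpy as np
from tensor_addition_nki_kernels import nki_tensor_add_kernel_

a = np.random.rand(256, 1024).astype(np.float32)  # a 2 x 2 grid of 128x512 tiles
b = np.random.rand(256, 1024).astype(np.float32)

grid_x, grid_y = a.shape[0] // 128, a.shape[1] // 512
c = nki_tensor_add_kernel_[grid_x, grid_y](a, b)
assert np.allclose(c, a + b)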
32 changes: 5 additions & 27 deletions src/tutorials/tensor_addition/tensor_addition_torch.py
@@ -4,38 +4,15 @@
PyTorch implementation for tensor addition NKI tutorial.
"""
# NKI_EXAMPLE_29_BEGIN
import torch
from torch_xla.core import xla_model as xm
from torch_neuronx import nki_jit
# NKI_EXAMPLE_29_END

from tensor_addition_nki_kernels import nki_tensor_add_kernel_
from tensor_addition_nki_kernels import nki_tensor_add


def nki_tensor_add(a_input, b_input):
"""NKI kernel caller to compute element-wise addition of two input tensors
This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs
Args:
a_input: a first input tensor, of shape [N*128, M*512]
b_input: a second input tensor, of shape [N*128, M*512]
Returns:
a tensor of shape [N*128, M*512], the result of a_input + b_input
"""

# The SPMD launch grid denotes the number of kernel instances.
# In this case, we use a 2D grid where the size of each invocation is 128x512
grid_x = a_input.shape[0] // 128
grid_y = a_input.shape[1] // 512
c_output = torch.zeros(a_input.shape, dtype=a_input.dtype).to(device=device)

# Decorate the NKI kernel for PyTorch tracing
nki_tensor_add_kernel_torch = nki_jit(nki_tensor_add_kernel_)
nki_tensor_add_kernel_torch[grid_x, grid_y](a_input, b_input, c_output)

return c_output

# NKI_EXAMPLE_29_BEGIN
if __name__ == "__main__":
device = xm.xla_device()

@@ -55,3 +32,4 @@ def nki_tensor_add(a_input, b_input):
print("NKI and Torch differ")

assert allclose
# NKI_EXAMPLE_29_END
16 changes: 5 additions & 11 deletions src/tutorials/transpose2d/transpose2d_jax.py
@@ -5,25 +5,18 @@
"""

# NKI_EXAMPLE_36_BEGIN
import jax
import jax.numpy as jnp
from functools import partial
from jax_neuronx import nki_call
# NKI_EXAMPLE_36_END

from transpose2d_nki_kernels import tensor_transpose2D_kernel_


def transpose2D(in_tensor, shape2D):
return nki_call(
partial(tensor_transpose2D_kernel_, shape2D=shape2D),
in_tensor,
out_shape=jax.ShapeDtypeStruct(in_tensor.shape, dtype=in_tensor.dtype)
)

# NKI_EXAMPLE_36_BEGIN
if __name__ == "__main__":
P, X, Y = 5, 37, 44
a = jax.random.uniform(jax.random.PRNGKey(42), (P, X * Y))
a_t_nki = transpose2D(a, (X, Y))
a_t_nki = tensor_transpose2D_kernel_(a, shape2D=(X, Y))

a_t_jax = jnp.transpose(a.reshape(P, X, Y), axes=(0, 2, 1)).reshape(P, X * Y)
print(a, a_t_nki, a_t_jax)
@@ -35,3 +28,4 @@ def transpose2D(in_tensor, shape2D):
print("NKI and JAX differ")

assert allclose
# NKI_EXAMPLE_36_END
13 changes: 9 additions & 4 deletions src/tutorials/transpose2d/transpose2d_nki_kernels.py
@@ -5,11 +5,13 @@
"""

import numpy as np
# NKI_EXAMPLE_33_BEGIN
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
@nki.jit
def tensor_transpose2D_kernel_(in_tensor, shape2D):
"""
NKI kernel to reorder the elements on axis[1] of the input tensor.
@@ -36,6 +38,8 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
shape2D: tuple representing the dimensions to be transposed: (#rows, #cols)
out_tensor: an output (transposed) tensor
"""
out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype,
buffer=nl.shared_hbm)
# Gather input shapes
sz_p, _ = in_tensor.shape

@@ -64,14 +68,15 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
# Finally, we store out_tile to external memory
nl.store(out_tensor, value=out_tile)

return out_tensor
# NKI_EXAMPLE_33_END


if __name__ == "__main__":
P, X, Y = 5, 3, 4
a = np.arange(P*X*Y, dtype=np.int8).reshape((P, X*Y))
a_t_nki = np.zeros((P, Y*X), dtype=np.int8)

tensor_transpose2D_kernel_torch = nki.baremetal(tensor_transpose2D_kernel_)
tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y))
a_t_nki = tensor_transpose2D_kernel_(a, (X, Y))

a_t_np = np.transpose(a.reshape(P, X, Y), (0, 2, 1)).reshape(P, X * Y)

8 changes: 5 additions & 3 deletions src/tutorials/transpose2d/transpose2d_torch.py
@@ -4,22 +4,23 @@
PyTorch implementation for transpose2d NKI tutorial.
"""

# NKI_EXAMPLE_34_BEGIN
import torch
from torch_xla.core import xla_model as xm
from torch_neuronx import nki_jit
# NKI_EXAMPLE_34_END

from transpose2d_nki_kernels import tensor_transpose2D_kernel_


# NKI_EXAMPLE_34_BEGIN
if __name__ == "__main__":
device = xm.xla_device()

P, X, Y = 5, 3, 4
a = torch.arange(P*X*Y, dtype=torch.int8).reshape((P, X*Y)).to(device=device)
a_t_nki = torch.zeros((P, Y*X), dtype=torch.int8).to(device=device)

tensor_transpose2D_kernel_torch = nki_jit(tensor_transpose2D_kernel_)
tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y))
a_t_nki = tensor_transpose2D_kernel_(a, (X, Y))

a_t_torch = torch.transpose(a.reshape(P, X, Y), 1, 2).reshape(P, X * Y)

@@ -32,3 +33,4 @@
print("NKI and PyTorch differ")

assert allclose
# NKI_EXAMPLE_34_END
2 changes: 2 additions & 0 deletions test/integration/flash_attention/flash_attention_benchmark.py
@@ -14,6 +14,8 @@

from flash_attention import nki_flash_attn_func

parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(parent_dir)
from perf_utils.LatencyCollector import benchmark

if len(sys.argv) != 2:
@@ -8,6 +8,8 @@
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(parent_dir)
from perf_utils.LatencyCollector import benchmark

import sys
@@ -23,6 +23,8 @@
else:
from diffusers.models.cross_attention import CrossAttention

parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(parent_dir)
from perf_utils.LatencyCollector import benchmark

import sys
19 changes: 9 additions & 10 deletions test/unit/test_SD_attention_small_head.py
@@ -11,7 +11,7 @@

test_trace_file_path='local_trace.ntff'
numeric_func = baremetal(fused_self_attn_for_SD_small_head_size)
bench_func = benchmark(warmup=5, iters=10, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size)
bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size)

def cpu_golden_attn(q, k, v):
softmax_scale = 0.125
@@ -34,16 +34,16 @@ def test_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency):
q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
k = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype)


q_dev = nl.static_cast(q, dtype)
k_dev = nl.static_cast(k, dtype)
v_dev = nl.static_cast(v, dtype)

bench_func[bs](q_dev, k_dev, v_dev, out)
latency_res = bench_func.benchmark_result.nc_latency
p99 = latency_res.get_latency_percentile(99)
assert p99 <= latency
bench_func_ = bench_func[bs]
bench_func_(q_dev, k_dev, v_dev)
latency_res = bench_func_.benchmark_result.nc_latency
p50 = latency_res.get_latency_percentile(50)
assert p50 <= latency * 1.05  # short-running kernels are subject to hardware fluctuation
assert os.path.getsize(test_trace_file_path) > 0

@pytest.mark.parametrize("bs,seqlen,d,dtype", [
@@ -54,13 +54,12 @@ def test_attention_for_SD_numberic(self, bs, seqlen, d, dtype):
q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
k = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype)


q_dev = nl.static_cast(q, dtype)
k_dev = nl.static_cast(k, dtype)
v_dev = nl.static_cast(v, dtype)

numeric_func[bs](q_dev, k_dev, v_dev, out)
out = numeric_func[bs](q_dev, k_dev, v_dev)
out = nl.static_cast(out, np.float32)
golden_result = cpu_golden_attn(q, k, v)
assert np.allclose(out, golden_result, atol=1e-2)
67 changes: 67 additions & 0 deletions test/unit/test_allocated_SD_attention_small_head.py
@@ -0,0 +1,67 @@
"""
Copyright (c) 2023, Amazon.com. All Rights Reserved
"""
import os
import pytest
from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size
from neuronxcc.nki import benchmark, baremetal
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import numpy as np
from scipy.special import softmax

test_trace_file_path='local_trace.ntff'
numeric_func = baremetal(allocated_fused_self_attn_for_SD_small_head_size)
bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(allocated_fused_self_attn_for_SD_small_head_size)

def cpu_golden_attn(q, k, v):
softmax_scale = 0.125
q_scaled = q * softmax_scale
raw_score = np.matmul(q_scaled.transpose(0, 2, 1), k)

norm_score = softmax(raw_score, axis=-1)

# Transpose the result so it has the same layout as ours
return np.matmul(norm_score, v).transpose(0, 2, 1)

class TestAttention:

@pytest.mark.parametrize("bs,seqlen,d,dtype,latency", [
[1, 4096, 128, np.float32, 410],
[1, 4096, 128, nl.bfloat16, 350],
[1, 5120, 128, nl.bfloat16, 586]
])
def test_allocated_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency):
q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
k = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)

q_dev = nl.static_cast(q, dtype)
k_dev = nl.static_cast(k, dtype)
v_dev = nl.static_cast(v, dtype)

bench_func_ = bench_func[bs]
bench_func_(q_dev, k_dev, v_dev)
latency_res = bench_func_.benchmark_result.nc_latency
p50 = latency_res.get_latency_percentile(50)
assert p50 <= latency * 1.05  # short-running kernels are subject to hardware fluctuation
assert os.path.getsize(test_trace_file_path) > 0

@pytest.mark.parametrize("bs,seqlen,d,dtype", [
[1, 4096, 128, np.float32],
[1, 4096, 128, nl.bfloat16],
[1, 5120, 128, nl.bfloat16]
])
def test_allocated_attention_for_SD_numberic(self, bs, seqlen, d, dtype):
q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
k = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)

q_dev = nl.static_cast(q, dtype)
k_dev = nl.static_cast(k, dtype)
v_dev = nl.static_cast(v, dtype)

out = numeric_func[bs](q_dev, k_dev, v_dev)
out = nl.static_cast(out, np.float32)
golden_result = cpu_golden_attn(q, k, v)
assert np.allclose(out, golden_result, atol=1e-2)
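
This new test file follows the same pytest structure as the existing unit tests (parametrized perf and numerical cases over the baremetal and benchmark wrappers). Assuming a host with the Neuron compiler and a NeuronCore available, it can be run on its own, for example:

import pytest

# Run only the allocated-attention unit tests added in this PR.
pytest.main(["test/unit/test_allocated_SD_attention_small_head.py", "-v"])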
31 changes: 12 additions & 19 deletions test/unit/test_flash_attn_bwd.py
@@ -7,6 +7,8 @@
import neuronxcc.nki.language as nl
import numpy as np

from TestDecorators import xfail

numeric_func = baremetal(flash_attn_bwd)
bench_func = benchmark(warmup=5, iters=10)(flash_attn_bwd)

@@ -85,6 +87,7 @@ def mixed_precision_matmul(a, b):

class TestAttention:

@xfail # P167481231
@pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, latency", [
[1, 4, 32*1024, 128, nl.bfloat16, 117000],
])
@@ -97,23 +100,16 @@ def test_flash_attn_bwd_perf(self, bs, nheads, seqlen, d, dtype, latency):
lse = np.random.random_sample([bs, nheads, nl.tile_size.pmax, seqlen // nl.tile_size.pmax]).astype(np.float32)
seed = None

out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
out_dk = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)

q = nl.static_cast(q, dtype)
k = nl.static_cast(k, dtype)
v = nl.static_cast(v, dtype)
o_proj = nl.static_cast(o_proj, dtype)
dy = nl.static_cast(dy, dtype)
out_dq = nl.static_cast(out_dq, dtype)
out_dk = nl.static_cast(out_dk, dtype)
out_dv = nl.static_cast(out_dv, dtype)

bench_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
out_dq, out_dk, out_dv,
use_causal_mask=True, mixed_precision=True)
latency_res = bench_func.benchmark_result.nc_latency

bench_func_ = bench_func[bs, nheads]
bench_func_(q, k, v, o_proj, dy, lse, seed,
use_causal_mask=True, mixed_precision=True)
latency_res = bench_func_.benchmark_result.nc_latency
p99 = latency_res.get_latency_percentile(99)
assert p99 <= latency

@@ -130,10 +126,7 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype):
v = nl.static_cast(v, dtype)
dy = nl.static_cast(dy, dtype)
seed = None
out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
out_dk = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)


dq_golden, dk_golden, dv_golden, cached_negative_max, cached_sum_reciprocal, o_proj = \
cpu_attention_backward(q, k, v, dy, use_causal_mask=True)
cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
@@ -142,9 +135,9 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype):
nl.tile_size.pmax).transpose(0, 1, 3, 2)
lse = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal))

numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
out_dq, out_dk, out_dv,
use_causal_mask=True, mixed_precision=True)
out_dq, out_dk, out_dv = numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
use_causal_mask=True,
mixed_precision=True)

assert np.allclose(out_dq, dq_golden, atol=1e-2)
assert np.allclose(out_dk, dk_golden, atol=1e-2)
69 changes: 39 additions & 30 deletions test/unit/test_flash_attn_fwd.py
@@ -63,75 +63,84 @@ def mixed_precision_matmul(a, b):

class TestAttention:

@pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\
@pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\
mixed_precision, training, tile_size, kv_heads, should_transpose_v, latency", [
[1, 6, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000],
[1, 1, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000],
[1, 6, 32*1024, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000],
[1, 1, 32*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000],
# Non-square
[1, 3, 32*1024, 16*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000],
[1, 3, 16*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000],
])
def test_flash_attn_fwd_perf(self, bs, nheads, seqlen, d, dtype, use_causal_mask,
def test_flash_attn_fwd_perf(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,
mixed_precision, training, tile_size, kv_heads, should_transpose_v,latency):
q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2
k = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2
if should_transpose_v:
v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2
else:
v = (np.random.random_sample([bs, nheads, seqlen, d]) - 0.5) * 2
o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype)
out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax],
v = (np.random.random_sample([bs, nheads, seqlen_k, d]) - 0.5) * 2
o_proj = np.zeros(shape=[bs, nheads, seqlen_q, d], dtype=dtype)
out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen_q // nl.tile_size.pmax],
dtype=nl.float32 if mixed_precision else dtype) if training else None
seed = None

q = nl.static_cast(q, dtype)
k = nl.static_cast(k, dtype)
v = nl.static_cast(v, dtype)
o_proj = nl.static_cast(o_proj, dtype)
config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v})

heads = nheads if kv_heads is None else kv_heads

bench_func[bs, heads](q, k, v, seed, o_proj, out_lse,
use_causal_mask=use_causal_mask, mixed_precision=mixed_precision, config=config)
latency_res = bench_func.benchmark_result.nc_latency
bench_func_ = bench_func[bs, heads]
bench_func_(q, k, v, seed, use_causal_mask=use_causal_mask,
mixed_precision=mixed_precision, config=config)
latency_res = bench_func_.benchmark_result.nc_latency
p99 = latency_res.get_latency_percentile(99)
assert p99 <= latency

@pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\
@pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\
training, tile_size, kv_heads, should_transpose_v", [
[1, 6, 4096, 128, np.float32, True, True, 2048, 3, False],
[1, 1, 4096, 128, np.float32, True, False, 2048, None, False],
[1, 6, 4096, 4096, 128, np.float32, True, True, 2048, 3, False],
[1, 1, 4096, 4096, 128, np.float32, True, False, 2048, None, False],
[1, 1, 8192, 4096, 128, np.float32, True, False, 2048, None, False],
[1, 1, 4096, 8192, 128, np.float32, True, False, 2048, None, False],
])
def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen, d, dtype, use_causal_mask,
def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,
training, tile_size, kv_heads, should_transpose_v):
q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen]) - 0.5) * 2
q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2
k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen_k]) - 0.5) * 2
if should_transpose_v:
v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2
cpu_permute = (0, 1, 2, 3)
else:
v = (np.random.random_sample([bs, kv_heads or nheads, seqlen, d]) - 0.5) * 2
v = (np.random.random_sample([bs, kv_heads or nheads, seqlen_k, d]) - 0.5) * 2
cpu_permute = (0, 1, 3, 2)
o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype)

q = nl.static_cast(q, dtype)
k = nl.static_cast(k, dtype)
v = nl.static_cast(v, dtype)
seed = None
out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax],
dtype=np.float32) if training else None

o_proj_golden, cached_negative_max, cached_sum_reciprocal = \
cpu_attention_forward(q, k, v.transpose(cpu_permute), use_causal_mask=use_causal_mask,mixed_precision=True)
o_proj_golden = o_proj_golden.transpose(0,1,3,2) # (b,h, d, seq)
cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax,
nl.tile_size.pmax).transpose(0, 1, 3, 2)
cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax,
nl.tile_size.pmax).transpose(0, 1, 3, 2)
lse_golden = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal)) if training else None
config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v})

heads = nheads if kv_heads is None else kv_heads
numeric_func[bs, heads](q, k, v, seed, o_proj, out_lse, seed,
use_causal_mask=use_causal_mask, mixed_precision=True, config=config)
results = numeric_func[bs, heads](q, k, v, seed,
use_causal_mask=use_causal_mask,
mixed_precision=True,
config=config)

assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
if training:
o_proj, out_lse = results
assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
assert np.allclose(out_lse, lse_golden, atol=1e-2)
else:
o_proj = results
assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
16 changes: 8 additions & 8 deletions test/unit/test_resize_nearest.py
@@ -11,6 +11,7 @@
numeric_func = baremetal(resize_nearest_fixed_dma_kernel)
bench_func = benchmark(warmup=5, iters=10)(resize_nearest_fixed_dma_kernel)


def cpu_golden_result(data_tensor, output_shape):
in_b, in_h, in_w, in_c = data_tensor.shape
out_b, out_h, out_w, out_c = output_shape
@@ -36,18 +37,18 @@ def cpu_golden_result(data_tensor, output_shape):
class TestResizeNearest:

@pytest.mark.parametrize("in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency", [
[10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1722],
[10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1740],
[1, 30, 20, 1280, 1, 59, 38, 1280, nl.float16, 659],
[1, 30, 20, 1280, 1, 59, 38, 1280, nl.bfloat16, 659],
])
def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency):
input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32)
output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype)


input_dev = nl.static_cast(input_tensor, dtype)

bench_func[in_b](input_dev, output_tensor)
latency_res = bench_func.benchmark_result.nc_latency
bench_func_ = bench_func[in_b]
bench_func_(input_dev, (out_b, out_h, out_w, out_c))
latency_res = bench_func_.benchmark_result.nc_latency
p99 = latency_res.get_latency_percentile(99)
assert p99 <= latency

@@ -58,11 +59,10 @@ def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out
])
def test_resize_nearest_for_numberic(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype):
input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32)
output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype)


input_dev = nl.static_cast(input_tensor, dtype)

numeric_func[in_b](input_dev, output_tensor)
output_tensor = numeric_func[in_b](input_dev, (out_b, out_h, out_w, out_c))
output_tensor = nl.static_cast(output_tensor, np.float32)
golden_result = cpu_golden_result(input_tensor, output_tensor.shape)
assert np.allclose(output_tensor, golden_result, atol=1e-2)
65 changes: 65 additions & 0 deletions test/unit/test_rmsnorm_qkv.py
@@ -0,0 +1,65 @@
"""
Copyright (c) 2024, Amazon.com. All Rights Reserved
"""
import pytest
from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv
from neuronxcc.nki import benchmark, baremetal
import neuronxcc.nki.language as nl
import numpy as np

numeric_func = baremetal(allocated_fused_rms_norm_qkv)
bench_func = benchmark(warmup=5, iters=10)(allocated_fused_rms_norm_qkv)

np.random.seed(0)


def rms_norm(hidden, gamma, eps=1e-6):
rms = np.sqrt(np.mean(np.square(hidden), axis=-1, keepdims=True))
norm = hidden * np.reciprocal(rms + eps)
if gamma is not None:
norm *= gamma
return norm

def cpu_golden_result(hidden, gamma, qkv_weights, dtype, do_norm=True):
if do_norm:
hidden = rms_norm(hidden, gamma)
qkv_out = (hidden @ qkv_weights).astype(dtype)
return qkv_out

class TestRMSNormQKV:
@pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype, latency", [
[1, 128, 512, 512, np.float16, 25],
[1, 512, 1024, 384, nl.bfloat16, 40],
[1, 128, 1024, 512, nl.bfloat16, 28],
[1, 1024, 8192, 512, nl.bfloat16, 301 * 1.02]
])
def test_allocated_rmsnorm_qkv_perf(self, batch, seqlen, dim, d_head, dtype, latency):
hidden = np.random.random_sample((batch, seqlen, dim)).astype(np.float32)
weights = np.random.random_sample((dim, d_head)).astype(np.float32)

hidden = nl.static_cast(hidden, dtype)
weights = nl.static_cast(weights, dtype)

bench_func(hidden, weights)
latency_res = bench_func.benchmark_result.nc_latency
p99 = latency_res.get_latency_percentile(99)
assert p99 <= latency

@pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype", [
[1, 128, 512, 512, np.float16],
[1, 512, 1024, 384, nl.bfloat16],
[1, 128, 1024, 512, nl.bfloat16],
[1, 1024, 8192, 512, nl.bfloat16]
])
def test_allocated_rmsnorm_qkv_numeric(self, batch, seqlen, dim, d_head, dtype):
hidden = np.random.random_sample((batch, seqlen, dim))
weights = np.random.random_sample((dim, d_head))

hidden_dev = nl.static_cast(hidden, dtype)
weights_dev = nl.static_cast(weights, dtype)

out = numeric_func(hidden_dev, weights_dev)
out = nl.static_cast(out, np.float32)
golden_res = nl.static_cast(cpu_golden_result(hidden, None, weights, dtype, do_norm=True), np.float32)
assert np.allclose(out, golden_res, atol=1e-2, rtol=1e-2)

32 changes: 14 additions & 18 deletions test/unit/test_select_and_scatter.py
@@ -39,39 +39,35 @@ def cpu_golden_result(operand_tensor, source_tensor, window_dimensions=(3, 3), w
out_h = h * stride_h + local_h - padding[0]
out_w = w * stride_w + local_w - padding[1]
output_tensor[n, c, out_h, out_w] += source_tensor[n, c, h, w]

return output_tensor

class TestSelectAndScatter:
@pytest.mark.parametrize("n, c, operand_h, operand_w, source_h, source_w, dtype, latency", [
[8, 64, 112, 112, 56, 56, np.float32, 4500],
])
def test_select_and_scatter_for_perf(self, n, c, operand_h, operand_w, source_h, source_w, dtype, latency):
operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32)
source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32)
output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype)

operand_dev = nl.static_cast(operand_tensor, dtype)
source_dev = nl.static_cast(source_tensor, dtype)
operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype)
source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype)

bench_func(operand_dev, source_dev, output_tensor)
bench_func(operand_dev, source_dev)
latency_res = bench_func.benchmark_result.nc_latency
p99 = latency_res.get_latency_percentile(99)
assert p99 <= latency

@pytest.mark.parametrize("n, c, operand_h, operand_w, source_h, source_w, dtype", [
[8, 64, 112, 112, 56, 56, np.float32],
pytest.param(8, 64, 112, 112, 56, 56, nl.bfloat16, marks=pytest.mark.xfail),
[8, 64, 112, 112, 56, 56, nl.bfloat16],
])
def test_select_and_scatter_for_numeric(self, n, c, operand_h, operand_w, source_h, source_w, dtype):
operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32)
source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32)
output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype)

operand_dev = nl.static_cast(operand_tensor, dtype)
source_dev = nl.static_cast(source_tensor, dtype)
operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype)
source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype)

sw = nl.static_cast(np.ndarray(shape=(n, c, source_h, source_w, 3, 3)), dtype)
operand_tensor = nl.static_cast(operand_dev, np.float32)
source_tensor = nl.static_cast(source_dev, np.float32)

numeric_func(operand_dev, source_dev, output_tensor)
output_dev = numeric_func(operand_dev, source_dev)
golden_result = cpu_golden_result(operand_tensor, source_tensor)
output_tensor = nl.static_cast(output_tensor, np.float32)
assert np.allclose(output_tensor, golden_result)
nki_result = nl.static_cast(output_dev, np.float32)

assert np.allclose(nki_result, golden_result, rtol=1e-2, atol=1e-2)