From a33dfbc883a80f3875ec534a6bb5e7d3b159543e Mon Sep 17 00:00:00 2001
From: neuron-code-sharing-robot
 <163452788+neuron-code-sharing-robot@users.noreply.github.com>
Date: Fri, 13 Dec 2024 13:29:52 -0800
Subject: [PATCH 1/3] NeuronSDK 2.21: NKI-Samples Update (#6)

Update nki-samples for NeuronSDK 2.21 Beta Release.
1. Use new nki.jit decorator for kernels.
2. Added samples for the new direct allocation feature.
3. Misc tests and documentation improvements.

Co-authored-by: aws-qieqingy <122939906+aws-qieqingy@users.noreply.github.com>
---
 CONTRIBUTING.md                               |  69 +----
 LICENSE                                       |  16 -
 LICENSE.txt                                   |   1 +
 README.md                                     |  93 ++++--
 src/reference/__init__.py                     |  22 +-
 src/reference/allocated_attention.py          | 283 ++++++++++++++++++
 src/reference/allocated_fused_linear.py       | 114 +++++++
 src/reference/attention.py                    | 213 ++++++++-----
 src/reference/tutorial.py                     |  22 +-
 src/reference/vision.py                       |  21 +-
 .../average_pool2d/average_pool2d_jax.py      |  24 +-
 .../average_pool2d_nki_kernels.py             |  49 ++-
 .../average_pool2d/average_pool2d_torch.py    |  11 +-
 .../fused_mamba/mamba_nki_kernels.py          |  38 ++-
 src/tutorials/fused_mamba/mamba_torch.py      |   7 +-
 .../layernorm/layernorm_nki_kernel.py         | 143 +++++----
 src/tutorials/layernorm/layernorm_torch.py    |  18 +-
 .../matrix_multiplication_nki_kernels.py      |  52 +++-
 .../matrix_multiplication_torch.py            |  13 +-
 src/tutorials/rmsnorm/rmsnorm_jax.py          |   9 +-
 src/tutorials/rmsnorm/rmsnorm_nki_kernels.py  |  14 +-
 src/tutorials/rmsnorm/rmsnorm_torch.py        |   8 +-
 .../sd_attention/sd_attention_nki_kernels.py  | 101 ++++---
 .../sd_attention/sd_attention_torch.py        |   9 +-
 .../tensor_addition/tensor_addition_jax.py    |  36 +--
 .../tensor_addition_nki_kernels.py            |  28 +-
 .../tensor_addition/tensor_addition_torch.py  |  32 +-
 src/tutorials/transpose2d/transpose2d_jax.py  |  16 +-
 .../transpose2d/transpose2d_nki_kernels.py    |  13 +-
 .../transpose2d/transpose2d_torch.py          |   8 +-
 .../flash_attention_benchmark.py              |   2 +
 .../sd2_512_benchmark.py                      |   2 +
 .../sd2_inpainting_936_624_benchmark.py       |   2 +
 test/unit/test_SD_attention_small_head.py     |  19 +-
 .../test_allocated_SD_attention_small_head.py |  67 +++++
 test/unit/test_flash_attn_bwd.py              |  31 +-
 test/unit/test_flash_attn_fwd.py              |  69 +++--
 test/unit/test_resize_nearest.py              |  16 +-
 test/unit/test_rmsnorm_qkv.py                 |  65 ++++
 test/unit/test_select_and_scatter.py          |  32 +-
 40 files changed, 1202 insertions(+), 586 deletions(-)
 delete mode 100644 LICENSE
 create mode 100644 LICENSE.txt
 create mode 100644 src/reference/allocated_attention.py
 create mode 100644 src/reference/allocated_fused_linear.py
 create mode 100644 test/unit/test_allocated_SD_attention_small_head.py
 create mode 100644 test/unit/test_rmsnorm_qkv.py

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 4c16260..32ce44e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,6 +1,6 @@
 # Contributing Guidelines
 
-Thank you for your interest in contributing to our project. Whether it's a new NKI kernel, improving existing kernel code, bug fix, new feature, correction, or additional
+Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
 documentation, we greatly value feedback and contributions from our community.
 
 Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
@@ -24,13 +24,14 @@ reported the issue. Please try to include as much information as you can. Detail
 Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
 
 1. You are working against the latest source on the *main* branch.
-2. You check existing open, and recently merged pull requests to make sure someone else hasn't addressed the problem already.
+2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
+3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
 
 To send us a pull request, please:
 
 1. Fork the repository.
-2. Modify the source; please focus on the specific changes you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
-3. Please ensure your change satisfies the requirements listed in [Testing Requirements](#testing-requirements) and [Coding Guidelines](#coding-guidelines)
+2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
+3. Ensure local tests pass.
 4. Commit to your fork using clear commit messages.
 5. Send us a pull request, answering any default questions in the pull request interface.
 6. Wait for a repository collaborator to look at your pull request, run the automated tests, and review. If additional changes or discussion is needed, a collaborator will get back to you, so please stay involved in the conversation.
@@ -39,64 +40,8 @@ To send us a pull request, please:
 GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
 [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
 
-### Testing Requirements
-Running the binaries for a NKI kernel require Neuron devices on an AWS EC2 instance from trn1, trn1n, or inf2 instance families. 
-Details on setting up an instance can be found in [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html).
 
-If you would like to test your kernel without requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` input/output tensors and types. 
-An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py). However, kernels with _only_ simulation tests will not be accepted.
-
-#### Requirements for Kernels Targeting `src/reference/`
-
-All kernels located in this folder need to have the following tests.
-
-1. Numeric accuracy tests with `nki.baremetal`. The output from the kernel
-must be validated against a CPU reference implementation. See `test_flash_attn_fwd_numerical` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.baremetal` is available at [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.baremetal.html).
-
-2. Performance benchmark tests with `nki.benchmark`. The unit test must have performance checks. At a minimum, put an assertion to verify p99 latency meets a certain threshold. See `test_flash_attn_fwd_perf` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.benchmark` is available at [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.benchmark.html)
-
-3. End-to-End integration tests that use your kernel in a model. 
-    
-    a. Each test should be in its own separate folder.
-
-    b. Each Test must have a `run.sh` script, that accepts an argument \<path to test_result.json\>. See [run.sh of FlashAttention](test/integration/flash_attention/run.sh) as an example. 
-
-    c. The test scripts must produce benchmark results with the `benchmark` function, located in [LatencyCollector.py](test/integration/perf_utils/LatencyCollector.py). The `benchmark` function will write the latency of your E2E model to the `test_result.json`.
-
-    d. Register your test target in [run_integration.sh](test/integration/run_integration.sh).
-
-
-### Coding Guidelines
-Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions. 
-In addition to PEP-8, we use the following NKI specific style guidelines:
-
-1. **Abbreviations**
-    * Importing NKI modules should use consistent names. For example,
-    ```
-    import neuronxcc.nki as nki
-    import neuronxcc.nki.isa as nisa
-    import neuronxcc.nki.language as nl
-    import neuronxcc.nki.typing as nt
-    import numpy as np
-    ```   
-2. Variable Names
-    * Indexing should specify partition and free dimensions along with the variable they are used for. For example:
-        The index for the partition dimension for tile `a` would be
-        ```
-        i_p_a = nl.arange(128)[:, None]
-        ```
-        while the index for the free dimension for tile `b` would be
-        ```
-        i_f_b = nl.arange(512)[None, :]
-        ```
-    * Name loop variables, indices, and buffers consistently, and specify their intended use in the name.
-
-3. Documentation
-   * New kernels should containing inline docstrings that describe the semantics of the kernel, and provide information on the IO layout. 
-   Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation. 
-
-
-## Finding Contributions to Work on
+## Finding contributions to work on
 Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws-neuron/nki-samples/labels/help%20wanted) issues is a great place to start.
 
 
@@ -106,7 +51,7 @@ For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of
 opensource-codeofconduct@amazon.com with any additional questions or comments.
 
 
-## Security Issue Notifications
+## Security issue notifications
 If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
 
 
diff --git a/LICENSE b/LICENSE
deleted file mode 100644
index 3b1fad4..0000000
--- a/LICENSE
+++ /dev/null
@@ -1,16 +0,0 @@
-MIT No Attribution
-
-Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of
-this software and associated documentation files (the "Software"), to deal in
-the Software without restriction, including without limitation the rights to
-use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
-the Software, and to permit persons to whom the Software is furnished to do so.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
-FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
-COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
-IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..e7f39e2
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1 @@
+TODO: Fill LICENSE after it is finalized
\ No newline at end of file
diff --git a/README.md b/README.md
index 60602a9..2d97f6b 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ At the core of the Neuron SDK is the Neuron Compiler, which takes computation gr
 them into highly optimized machine code. 
 
 [NKI](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki) is a Python-based programming environment designed for the compiler which
-adopts commonly used NumPy and Triton-like syntax along with tile-level semantics. 
+adopts commonly used NumPy andTriton-like syntax along with tile-level semantics. 
 NKI also interoperates with the Neuron Profiler, providing insights into performance bottlenecks and instruction latencies. 
 It offers tensor printing support, standard error messaging, and built-in kernel simulation capabilities for efficient debugging purposes. 
 NKI offers two types of programming interfaces: 
@@ -16,25 +16,31 @@ enabling bare-metal access to the chip for full control.
 
 ![alt "High-level flow of NKI in the Neuron Compiler. NKI emits IR immediately before the backend-IR compilation stage"](doc_assets/high-level-nki-flow.png#center "High-Level NKI Flow")
 
-## Documentation
-The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/). 
-Documentation for NKI kernels are both inline (docstring) and available on the documentation site's 
-[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html).
+### nki.language 
+**nki.language** enables precise control over computation and data movement on NeuronCores-- the processing units within AWS Inferentia and Trainium chips. 
+Developers can control data movement between device memory and on-chip memory explicitly using `nl.load()` and `nl.store()` operations. 
+Developers can then perform the desired computations on loaded tensors, such as element-wise operations or tensor contractions, 
+providing crucial performance improvements. Additionally, developers can control how computation is performed on different compute engines inside NeuronCores. 
+nki.language APIs are considered high-level APIs and are designed for "ease of use" for ML practitioners. 
+To achieve the best performance, developers can enlist the nki.isa APIs.
+
+![alt "Diagram of the NeuronCore Architecture. It shows 4 engines: tensor, vector, scalar, and GPSIMD, connected to SBUF memory. The tensor, vector, and scalar engines are also connected to a high-speed PSUM memory bank that supports accumulate on write. Lastly the HBM (DRAM) is connected to both SBUF and PSUM memory banks."](doc_assets/pm-nc.png#scale_50#center "NeuronCore Architecture")
+
+### nki.isa
+
+**nki.isa** provides direct access to chip instructions to offer flexibility and fine-grained control over instruction usage and performance optimizations. 
+Developers can utilize various `nki.isa` instructions using the Tensor, Vector, Scalar, GP-SIMD, and DMA engines. 
+For example, developers can use `nki.isa.nc_matmul()` to compute a matrix multiplication using Tensor Engine. 
+Alternatively, developers can use `nki.isa.activation()` to apply an activation function on every element of the input tile using Scalar Engine.
 
 ## Repository Structure
 
 ### src
 
 #### reference
-This folder contains the source code of the `neuronxcc.nki.kernels`, and they are optimized kernels from the Neuron Team serving as samples. 
-
-All kernels located in this folder have numeric accuracy tests 
+The [reference kernels](src/reference/) are optimized reference kernels. All kernels located in this folder must have all of numeric accuracy tests 
 and performance benchmarks defined in the [test](test/) directory. We also demonstrate using these kernels end-to-end in our [integration tests](test/integration/).
 
-Note that these kernels are already being deployed as part of the Neuron stack. With flash attention as an example,
-[compiling Llama models with transformers-neuronx](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/transformers-neuronx-developer-guide.html)
-will automatically invoke the `flash_fwd` kernel in [attention.py](src/reference/attention.py). Therefore, replacing the framework operators with these NKI kernels likely won't result in extra performance benefit.
-
 
 #### tutorials
 The [tutorial kernels](src/tutorials/) are for educational purpose and include the kernels that are used in NKI guides. 
@@ -52,16 +58,65 @@ verify the numeric accuracy of the operation, and publish performance results to
 The [integration tests](tests/integration) folder contains integration tests of (selected) kernels. They verify the numeric accuracy of the model’s output, 
 and publish end-to-end performance results into the [integration benchmarks](docs/benchmarks/integration) folder.
 
-## Maintenance Policy
-NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. NKI API follow the [Neuron SDK Maintenance Policy](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/sdk-policy.html).
+## Documentation
+The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/). 
+Documentation for NKI kernels are both inline (docstring) and available on the documentation site's 
+[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html).
 
-## Getting Help
-Have a look at the GitHub issues for this repository where you will find past issues customers have encountered with workarounds and clarifications. 
-If you cannot find a suitable issue for your use-case feel free to [file an issue](https://github.com/aws-neuron/nki-samples/issues/new) to ask for assistance or to suggest improvements. Please read [CONTRIBUTING.md](CONTRIBUTING.md) for detailed information on submitting issues.
+## Versioning
+NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. We will also be updating the NKI API as needed 
+to support new Neuron and Neuron Compiler features. While NKI is in beta we may need to make backwards-incompatible changes to incorporate feedback from 
+our users or to support new use-cases of NKI on Neuron devices. Upon releasing NKI as generally available (GA), we will commit to not making backwards 
+incompatible changes to the NKI API for any supported version of the Neuron compiler. 
 
 ## Contributing
-We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via
-GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more information
+We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via. 
+GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements.
+
+### Getting Help
+Have a look at the GitHub issues for this repository where you will find past issues customers have encountered with workarounds and clarifications. 
+If you cannot find a suitable issue for your use-case feel free to file an issue asking for assistance or to suggest improvements.
+
+In addition, extensive NKI documentation can be found [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki).
+
+### Testing and Merging
+Running the binaries for a NKI kernel require Neuron devices on an AWS EC2 instance from trn1, trn1n, or inf2 instance families. 
+Details on setting up an instance can be found in [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html).
+
+Before merging, the Neuron team will need to internally test and verify kernels work as expected. If the change is accepted, 
+we will manually merge your changes, and it will be merged here upon the next release. 
+
+If you would like to test your kernel without a requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` tensors and types. 
+An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py).
+
+### Coding Guidelines
+Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions. 
+In addition to PEP-8, we use the following NKI specific style guidelines:
+
+1. **Abbreviations**
+    * Importing NKI modules should use consistent names. For example,
+    ```
+    import neuronxcc.nki as nki
+    import neuronxcc.nki.isa as nisa
+    import neuronxcc.nki.language as nl
+    import neuronxcc.nki.typing as nt
+    import numpy as np
+    ```   
+2. Variable Names
+    * Indexing should specify partition and free dimensions along with the variable they are used for. For example:
+        The index for the partition dimension for tile `a` would be
+        ```
+        i_p_a = nl.arange(128)[:, None]
+        ```
+        while the index for the free dimension for tile `b` would be
+        ```
+        i_f_b = nl.arange(512)[None, :]
+        ```
+    * Name loop variables, indices, and buffers consistently, and specify their intended use in the name.
+
+3. Documentation
+   * New kernels should containing inline docstrings that describe the semantics of the kernel, and provide information on the IO layout. 
+   Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation. 
 
 ## Licensing
 This repository is licensed under the terms of the [MIT-0 License](LICENSE.txt)
\ No newline at end of file
diff --git a/src/reference/__init__.py b/src/reference/__init__.py
index ad4a18a..922dd83 100644
--- a/src/reference/__init__.py
+++ b/src/reference/__init__.py
@@ -6,7 +6,27 @@
 Kernels here are the same to the ones available in the 
 NKI Github Sample Repo.
 
-TODO: Insert link to Github Repo when available
+https://github.com/aws-neuron/nki-samples
 """
 from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size, flash_attn_bwd, flash_fwd
 from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel, select_and_scatter_kernel
+from neuronxcc.nki.kernels.tutorial import add_kernel_nx8x128x512
+from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size
+from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv
+
+from neuronxcc.nki._private_kernels.legacy.attention import \
+  (fused_self_attn_for_SD_small_head_size as _fused_self_attn_for_SD_small_head_size,
+   flash_attn_bwd as _flash_attn_bwd, flash_fwd as _flash_fwd)
+from neuronxcc.nki._private_kernels.legacy.vision import (
+  resize_nearest_fixed_dma_kernel as _resize_nearest_fixed_dma_kernel,
+  select_and_scatter_kernel as _select_and_scatter_kernel)
+from neuronxcc.nki._private_kernels.legacy.tutorial import add_kernel_nx8x128x512 as _add_kernel_nx8x128x512
+from neuronxcc.nki._private_kernels.legacy.allocated_fused_linear import _allocated_fused_rms_norm_qkv
+
+fused_self_attn_for_SD_small_head_size._legacy_func = _fused_self_attn_for_SD_small_head_size
+flash_attn_bwd._legacy_func = _flash_attn_bwd
+flash_fwd._legacy_func = _flash_fwd
+resize_nearest_fixed_dma_kernel._legacy_func = _resize_nearest_fixed_dma_kernel
+select_and_scatter_kernel._legacy_func = _select_and_scatter_kernel
+add_kernel_nx8x128x512._legacy_func = _add_kernel_nx8x128x512
+allocated_fused_rms_norm_qkv._legacy_func = _allocated_fused_rms_norm_qkv
diff --git a/src/reference/allocated_attention.py b/src/reference/allocated_attention.py
new file mode 100644
index 0000000..564412c
--- /dev/null
+++ b/src/reference/allocated_attention.py
@@ -0,0 +1,283 @@
+import functools
+import neuronxcc.nki as nki
+import neuronxcc.nki.language as nl
+import neuronxcc.nki.isa as nisa
+import neuronxcc.nki.compiler as ncc
+from neuronxcc.nki.language import par_dim
+import numpy as np
+
+@nki.jit
+def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref,
+                                           use_causal_mask=False,
+                                           mixed_percision=True):
+  """
+  Allocated fused self attention kernel for small head size Stable Diffusion workload.
+  
+  Computes (softmax(Q.T@K)V).T. The wired layout is choosen to avoid transpose as
+  much as possible to simplify the debug. The kernel uses the direct allocation API,
+  and implements double buffering to achive better performance than automatic allocation.
+  As of NeuronSDK 2.21, it achieves 18% better performance than auto allocated equivalent.
+  To see the performance gap, you can use ``force_auto_alloc`` decorator to override
+  manual allocation and benchmark the performance difference.
+
+  This kernel is designed to be used for Stable Diffusion models where the 
+  n_heads is equal to 128. Seqlen must be divisible by 1024, and smaller than 5120. 
+  Assertion is thrown if ``n_heads`` or sequence length does not satisfy the requirement.
+  These restrictions are to simplify the address calculation in allocations.
+
+  IO tensor layouts:
+   - q_ptr: shape   (bs, d_heads, seq_q)
+   - k_ptr: shape   (bs, d_heads, seq_k)
+   - v_ptr: shape   (bs, seq_v, n_heads)
+   - out_ptr: shape (bs, d_heads, seq_q)
+   - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k
+
+  IO tensor dtypes:
+   - This kernel assumes all IO tensors have the same dtype
+   - If mixed_percision is True, then all Tensor Engine operation will be performed in
+     bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
+     will be in the same type as the inputs.
+  """
+  # Use q_ref dtype as the intermediate tensor dtype
+  # Assume all IO tensors have the same dtype
+  kernel_dtype = np.float32
+  pe_in_dt = nl.bfloat16 if mixed_percision else kernel_dtype
+
+  kernel_dtype_itemsize = np.dtype(kernel_dtype).itemsize
+  pe_in_dt_itemsize = np.dtype(pe_in_dt).itemsize
+  assert q_ref.dtype == k_ref.dtype == v_ref.dtype
+
+  # Shape checking
+  bs, d_head, seqlen = q_ref.shape
+  assert d_head <= 128, "Cannot use this kernel for d_head > 128"
+  assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!'
+  assert tuple(k_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!'
+  assert tuple(v_ref.shape) == (bs, seqlen,
+                                d_head), f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
+  out_ref = nl.ndarray((bs, d_head, seqlen), dtype=q_ref.dtype, buffer=nl.shared_hbm)
+
+  assert d_head == 128
+
+  cur_addr = 0
+
+  id0 = nl.arange(0, 128)[:, None]
+  id1 = nl.arange(0, 128)[None, :]
+  identity = nl.shared_constant(np.identity(128, dtype=np.int8), dtype=nl.bfloat16)
+  identity_load = nl.ndarray((par_dim(128), 128), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr))
+  cur_addr += 128 * pe_in_dt_itemsize
+  identity_load[id0, id1] = nl.load(identity)
+
+  identity_load_fp32 = nl.ndarray((par_dim(128), 128), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr))
+  cur_addr += 128 * np.dtype(np.float32).itemsize
+  identity_load_fp32[id0, id1] = nl.load(identity)
+
+  # Softmax scaling factor, multiplied onto Q
+  softmax_scale = 0.125
+
+  # Different batch samples/attention heads have independent attention
+  batch_id = nl.program_id(axis=0)
+
+  q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128
+  k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512
+  # No tiling on d_head dimension since the number of d_head fits in SB
+  d_head_tile_size = d_head
+  v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128
+
+  ###################################
+  # Step 1. preload tensors
+  ###################################
+  v_local = nl.ndarray((v_seq_n_tiles, par_dim(v_seq_tile_size), d_head), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(v_seq_n_tiles, ))) # 8kb
+  cur_addr += v_seq_n_tiles * d_head * pe_in_dt_itemsize
+
+  for i_v_seq_tile in nl.affine_range(v_seq_n_tiles):
+    ip_v = nl.arange(v_seq_tile_size)[:, None]
+    if_v = nl.arange(d_head_tile_size)[None, :]
+    v_local[i_v_seq_tile, ip_v, if_v] = nl.load(
+      v_ref[batch_id, i_v_seq_tile * v_seq_tile_size + ip_v, if_v],
+      dtype=pe_in_dt)
+
+  q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(q_seq_n_tiles, ))) # 8kb
+  cur_addr += q_seq_n_tiles * q_seq_tile_size * pe_in_dt_itemsize
+  ip_q = nl.arange(d_head_tile_size)[:, None]
+  if_q = nl.arange(q_seq_tile_size)[None, :]
+  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
+    q_local[i_q_seq_tile, ip_q, if_q] = nl.load(
+      q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q],
+      dtype=pe_in_dt)
+    q_local[i_q_seq_tile, ip_q, if_q] = nl.multiply(q_local[i_q_seq_tile, ip_q, if_q], softmax_scale)
+
+  k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(k_seq_n_tiles, ))) # 8kb
+  cur_addr += k_seq_n_tiles * k_seq_tile_size * pe_in_dt_itemsize
+  ip_k = nl.arange(d_head_tile_size)[:, None]
+  if_k = nl.arange(k_seq_tile_size)[None, :]
+  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
+    k_local[i_k_seq_tile, ip_k, if_k] = nl.load(
+      k_ref[batch_id,
+            ip_k,
+            i_k_seq_tile * k_seq_tile_size + if_k
+            ],
+      dtype=pe_in_dt)
+
+  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles//2):  # indent = 2
+    # perform activation and reduction in softmax in larger tile to amortize instruction overhead
+    reduction_size = 1024
+    reduction_tiles = seqlen // reduction_size
+
+    # =================================== SBUF Allocation Starts ===================================
+
+    # The num_free_tiles is intentionally set to (1, ) to disable double buffering on the first matmul.
+    # From the profile, when the first matmul is double buffered, the tensor_scalar_reduce instruction that writes to this buffer
+    # spends long time waiting for the matmul it depends on to be executed. The instruction scheduler made a bad decision and 
+    # clogged the pipeline when double buffering is on. This is a workaround to hint the scheduler.
+    qk_res_buf = nl.ndarray((2, par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(1, ))) # 32 k
+    cur_addr += seqlen * kernel_dtype_itemsize
+    exp_res = nl.ndarray((2, par_dim(q_seq_tile_size), seqlen),dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 16 kb
+    cur_addr += seqlen * 2 * pe_in_dt_itemsize
+    trans_softmax_res = nl.ndarray(
+        (2, par_dim(v_seq_tile_size), seqlen), name='trans_softmax_res',
+        dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 16kb
+    cur_addr += seqlen * 2 * pe_in_dt_itemsize
+    
+    sum_divisor = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 1kb
+    cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize
+    sum_reciprocal_broadcast = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 1kb
+    cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize
+    
+    attn_res_sbuf = nl.ndarray((2, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype,
+                                buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, )), name="attn_res_sbuf") # 1kb
+    cur_addr += 2 * q_seq_tile_size * kernel_dtype_itemsize
+    attn_res_div = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype,
+                                buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2,))) # 1kb
+    cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize
+    
+    neg_max_res = nl.ndarray((2, par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 64b
+    cur_addr += 2 * k_seq_n_tiles * kernel_dtype_itemsize
+    partial_sum_res = nl.ndarray((2, par_dim(q_seq_tile_size), reduction_tiles), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 32b
+    cur_addr += 2 * reduction_tiles * kernel_dtype_itemsize
+    neg_max_res_final = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b
+    cur_addr += 2 * 1 * kernel_dtype_itemsize
+    sum_res = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b
+    cur_addr += 2 * 1 * kernel_dtype_itemsize
+    sum_reciprocal = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b
+    cur_addr += 2 * 1 * kernel_dtype_itemsize
+
+    # =================================== SBUF Allocation End ===================================
+ 
+    qk_psum = nl.ndarray((2, k_seq_n_tiles, par_dim(q_seq_tile_size), k_seq_tile_size),
+                          dtype=np.float32, buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(2, 4)))
+    
+    assert k_seq_tile_size == 4 * v_seq_tile_size
+    local_tp_buf = nl.ndarray((2, k_seq_n_tiles, par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32,
+                                  buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(2, 4)))
+    
+    def psum_addr(bank_map, idx, pdim_size, fdim_size):
+      return (bank_map[idx], 0, 0)
+    
+    # Result psum buffer has the hidden dim as P
+    # qk_psum is using 0, 1, 2, 3 for fisrt interleave group, and 4, 5, 6, 7 for the second.
+    # assign 1 and 5 avoid bank collision between groups
+    attn_res_psum = nl.ndarray((2, par_dim(d_head_tile_size), q_seq_tile_size),
+                            dtype=np.float32, buffer=ncc.psum.alloc(functools.partial(psum_addr, bank_map={(0, ): 1, (1, ): 5})))
+
+    sum_local_tp_buf = nl.ndarray((2, par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32,
+                                  buffer=ncc.psum.alloc(functools.partial(psum_addr, bank_map={(0, ): 2, (1, ): 7})))
+
+    for i_interleave_grp in nl.affine_range(2):
+      # A SBUF buffer tile for an independent softmax tile
+      ip_max = nl.arange(q_seq_tile_size)[:, None]
+      if_max = nl.arange(k_seq_n_tiles)[None, :]
+
+      # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
+      for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):  # indent = 4
+
+        # Tensor indices for accessing qk result in k_seq_tile_size
+        ip_qk = nl.arange(q_seq_tile_size)[:, None]
+        if_qk = nl.arange(k_seq_tile_size)[None, :]
+
+        ##############################################################
+        # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
+        ##############################################################
+        qk_psum[i_interleave_grp, i_k_seq_tile, ip_qk, if_qk] = nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k],
+                                                stationary=q_local[i_q_seq_tile*2+i_interleave_grp, ip_q, if_q])
+
+        ###################################
+        # Step 3. Apply optional causal mask
+        ###################################
+        if use_causal_mask:
+          assert not use_causal_mask, "Causal mask not supported yet!"
+          # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+          qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select(
+            pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk),
+            on_true_tile=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype)
+        else:
+          # Copy result to SBUF and find partial maximum for softmax
+          qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.tensor_scalar_reduce(qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], np.add, 1.0,
+              reduce_op=np.max, reduce_res=neg_max_res[i_interleave_grp, ip_max, i_k_seq_tile], dtype=kernel_dtype)
+
+      # Find global max from tiles
+      neg_max_res_final[i_interleave_grp, ip_max, 0] = nisa.tensor_reduce(
+        np.max, data=neg_max_res[i_interleave_grp, ip_max, if_max],
+        axis=(1,), dtype=kernel_dtype, negate=True)
+
+      ip_softmax = nl.arange(q_seq_tile_size)[:, None]
+      if_softmax = nl.arange(seqlen)[None, :]
+      ip_sum_res = nl.arange(q_seq_tile_size)[:, None]
+      if_sum_res = nl.arange(d_head_tile_size)[None, :]
+
+      if_reduction = nl.arange(reduction_size)[None, :]
+      for i_exp in nl.affine_range(reduction_tiles):
+        exp_res[i_interleave_grp, ip_softmax, i_exp*reduction_size + if_reduction] = nisa.activation_reduce(np.exp,
+          data=qk_res_buf[i_interleave_grp, ip_softmax, i_exp * reduction_size + if_reduction],
+          reduce_op=np.sum, reduce_res=partial_sum_res[i_interleave_grp, ip_softmax, i_exp],
+          bias=neg_max_res_final[i_interleave_grp, ip_max, 0], scale=1.0,                                                                                          
+        )
+
+      sum_res[i_interleave_grp, ip_softmax, 0] = nisa.tensor_reduce(np.add, data=partial_sum_res[i_interleave_grp, :, :], axis=(1,),
+                            dtype=kernel_dtype)
+      
+      sum_reciprocal[i_interleave_grp, ip_softmax, 0] = nl.divide(1.0, sum_res[i_interleave_grp, ip_softmax, 0])
+      sum_reciprocal_broadcast[i_interleave_grp, ip_softmax, if_sum_res] = sum_reciprocal[i_interleave_grp, ip_softmax, 0].broadcast_to((q_seq_tile_size, d_head_tile_size))
+      sum_divisor[i_interleave_grp, ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast[i_interleave_grp, ip_softmax, if_sum_res], dtype=kernel_dtype)
+
+      ###################################
+      # Step 5. transpose(softmax_res)
+      ###################################
+      ip_scores_t = nl.arange(v_seq_tile_size)[:, None]
+      if_scores_t = nl.arange(v_seq_tile_size)[None, :]
+      # Loop over matmul_1 contraction
+      for i_v_seq_tile in nl.affine_range(v_seq_n_tiles // 4):
+        for i_offset in nl.affine_range(4):
+          ip_scores = nl.arange(v_seq_tile_size)[:, None]
+          if_scores = nl.arange(v_seq_tile_size)[None, :]
+          
+          local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, i_offset*v_seq_tile_size + if_scores] = nisa.nc_matmul(
+            exp_res[i_interleave_grp, ip_scores, (i_v_seq_tile*4+i_offset) * v_seq_tile_size + if_scores],
+            identity_load)
+
+        if_batch = nl.arange(k_seq_tile_size)[None, :]
+        trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*k_seq_tile_size + if_batch] = nl.copy(local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, if_batch])
+
+      ip_out = nl.arange(d_head_tile_size)[:, None]
+      if_out = nl.arange(q_seq_tile_size)[None, :]
+
+      for i_v_seq_tile in nl.affine_range(v_seq_n_tiles):
+        ######################################################################
+        # Step 6. matmul_1(stationary=v_local, moving=trans_softmax_res, contract=seqlen_v=seqlen_k)
+        ######################################################################
+        ip_v_t = nl.arange(v_seq_tile_size)[:, None]
+        if_v_t = nl.arange(d_head_tile_size)[None, :]
+        attn_res_psum[i_interleave_grp, ip_out, if_out] += \
+          nisa.nc_matmul(moving=trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*v_seq_tile_size+if_scores_t],
+                        stationary=v_local[i_v_seq_tile, ip_v_t, if_v_t])
+      
+      attn_res_sbuf[i_interleave_grp, ip_out, if_out] = nisa.tensor_copy(attn_res_psum[i_interleave_grp, ip_out, if_out], 
+                                                                    dtype=kernel_dtype, engine=nisa.vector_engine)
+
+      sum_local_tp_buf[i_interleave_grp, ip_sum_res, if_sum_res] = nisa.nc_matmul(sum_divisor[i_interleave_grp, ip_sum_res, if_sum_res], identity_load_fp32)
+      attn_res_div[i_interleave_grp, ip_sum_res, if_sum_res] = attn_res_sbuf[i_interleave_grp, :, :] * sum_local_tp_buf[i_interleave_grp, ip_sum_res, if_sum_res]
+
+      nl.store(
+        out_ref[batch_id, ip_out, (i_q_seq_tile*2+i_interleave_grp) * q_seq_tile_size + if_out],
+        value=attn_res_div[i_interleave_grp, :, :])
+      
+  return out_ref
\ No newline at end of file
diff --git a/src/reference/allocated_fused_linear.py b/src/reference/allocated_fused_linear.py
new file mode 100644
index 0000000..21e32af
--- /dev/null
+++ b/src/reference/allocated_fused_linear.py
@@ -0,0 +1,114 @@
+"""
+Copyright (c) 2024, Amazon.com. All Rights Reserved
+
+kernels - Fused normalization with linear layers
+
+"""
+
+import neuronxcc.nki.language as nl
+import neuronxcc.nki.isa as nisa
+import neuronxcc.nki.compiler as ncc
+import math
+import numpy as np
+from neuronxcc import nki
+from neuronxcc.nki.language import par_dim
+
+@nki.jit
+def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-6):
+  """
+  Allocated kernel that computes RMSNorm(hidden) @ wQKV. This kernel is designed to only handle fp16/bf16 tensor types.
+  Internally, normalizations are cast to fp32 to avoid NaN errors.
+
+  Args:
+      hidden (_type_): Input tensor of the attention block in BSH layout
+      weights (_type_): Fused QKV linear weights, assumed to be eltwise-multiplied with RMS norm weight vector (gamma)
+      out_tensor (_type_): Output tensor
+      norm_dtype (_type_, optional): Data type for RMS norm, should be f32 to avoid NaN. Defaults to nl.float32.
+      eps (_type_, optional): RMS norm epsilon term. Defaults to 1e-6.
+  """
+  # Hidden should be in BSH layout.
+  batch, batchless_shape = hidden.shape[0], hidden.shape[1:]
+  seqlen, dim = batchless_shape
+  _dim, head_dim = weights.shape
+
+  assert dim <= 8192 and dim & 128 == 0, "Unsupported hidden dimension"
+  assert _dim == dim, "Reduction dimension must match"
+  assert head_dim <= 512, "Head dimension must be 512 or less"
+
+  out_tensor = nl.ndarray((batch, seqlen, head_dim), dtype=hidden.dtype, buffer=nl.shared_hbm)
+
+  pmax, fmax = nl.tile_size.pmax, nl.tile_size.psum_fmax # 128, 512
+  ix, iy = nl.mgrid[0:pmax, 0:dim]
+  i_lhs = nl.mgrid[0:pmax, 0:pmax]
+  i_rhs = nl.mgrid[0:pmax, 0:fmax]
+  i_res = nl.mgrid[0:pmax, 0:fmax]
+  M = math.ceil(dim / pmax)
+  NUM_TRANSP_TILES = math.ceil(dim / fmax)
+  NUM_TILES = math.ceil(seqlen / pmax)
+  TILES_INT = math.ceil(NUM_TILES / 2)
+  scale = 1 / dim
+
+  iden_x, iden_y = nl.mgrid[0:pmax, 0:128]
+
+  identity_a = nl.shared_constant(np.identity(n=128, dtype=np.int8), dtype=hidden.dtype)
+  identity_tensor = nl.ndarray((par_dim(pmax), 128), dtype=weights.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=0))
+  identity_tensor[iden_x, iden_y] = nl.load(identity_a, dtype=weights.dtype)
+  bias_placeholder = nl.ndarray((par_dim(pmax), 1), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=128*2))
+  bias_placeholder[...] = 0
+  
+  for b in nl.affine_range(batch):
+    weights_buffer = nl.ndarray((M, par_dim(pmax), fmax), dtype=weights.dtype,
+                                buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim+fmax)*2+(dim+1)*4, num_free_tiles=(M,)))
+    # Preload the entire weights tensor. everything fits in SBUF for LLaMA 3.1 70B
+    for m in nl.affine_range(M):
+      weights_buffer[m, i_rhs.p, i_rhs.x] = nl.load(weights[m*pmax+i_rhs.p, i_rhs.x],
+                                                    mask=(m*pmax+i_rhs.p<dim) & (i_rhs.x<head_dim))
+    for i in nl.affine_range(TILES_INT):
+      # Double buffer the input tensor
+      in_bufs = nl.ndarray((2, par_dim(pmax), dim), dtype=hidden.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260, num_free_tiles=(2,)))
+      for i_interleave_grp in nl.affine_range(2):
+        in_bufs[i_interleave_grp] = nl.load(hidden[b, (2*i+i_interleave_grp)*pmax+ix, iy], mask=(2*i+i_interleave_grp)*pmax+ix < seqlen)
+        act = nl.ndarray((par_dim(pmax), dim), dtype=norm_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2))
+
+        # Write the RMS and RMS Reciprocal tensors back out here, in-place
+        square_sum = nl.ndarray((par_dim(pmax), 1), dtype=norm_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2+(dim)*4))
+
+        # Write the output of RMS and RMS^T (in-place) out to here
+        out_tile = nl.ndarray((par_dim(pmax), dim), dtype=weights.dtype,
+                              buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2+(dim+1)*4))
+        
+        # Store the final output tiles to here before sending back to DRAM
+        output_sbuf = nl.ndarray((par_dim(pmax), fmax), dtype=weights.dtype,
+                                buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim)*2+(dim+1)*4))
+
+        act[...] = nisa.activation_reduce(op=nl.square, data=in_bufs[i_interleave_grp], reduce_op=np.add, reduce_res=square_sum[...], bias=bias_placeholder[...])
+        square_sum[...] = nisa.tensor_scalar(square_sum[...], np.multiply, scale, op1=np.add, operand1=eps)
+        square_sum[...] = nisa.activation(op=nl.rsqrt, data=square_sum[...], bias=bias_placeholder[...])
+
+        # all PE array ops must output to FP32 on trn1 but must match input dtype in trn2
+        if nisa.get_nc_version() == nisa.nc_version.gen3:
+          transpose_res_psum = nl.ndarray((NUM_TRANSP_TILES, par_dim(pmax), 4*pmax), dtype=weights.dtype,
+                                          buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(1,))) # FIXME: perf is better when all tiles are on bank 0?
+        else:
+          transpose_res_psum = nl.ndarray((NUM_TRANSP_TILES, par_dim(pmax), 4*pmax), dtype=np.float32,
+                                          buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(1,))) # FIXME: perf is better when all tiles are on bank 0?
+
+        for m in nl.affine_range(NUM_TRANSP_TILES):
+          # Perform (hidden .* RMS Reciprocal)^T in tiles of fmax (512)
+          out_tile[i_rhs.p, m*fmax+i_rhs.x] = nl.multiply(in_bufs[i_interleave_grp, i_rhs.p, m*fmax + i_rhs.x], square_sum[...], dtype=weights.dtype)
+          for j in nl.affine_range(4):
+            transpose_res_psum[m, i_lhs.p, j*pmax+i_lhs.x] = nisa.nc_matmul(out_tile[i_lhs.p, (m*4+j) * pmax + i_lhs.x], identity_tensor[...],
+                                                                            is_transpose=True)
+          out_tile[i_rhs.p, m * 4*pmax + i_rhs.x] = nl.copy(transpose_res_psum[m], dtype=hidden.dtype)
+        
+        # perform (RMSNorm(hidden)^T)^T @ wQKV
+        res_psum = nl.ndarray((1, par_dim(pmax), fmax), dtype=nl.float32,
+                              buffer=ncc.psum.mod_alloc(base_bank=7, num_bank_tiles=(1,)))
+        for m in nl.affine_range(M):
+          res_psum[0] += nisa.nc_matmul(out_tile[i_lhs.p, m*pmax+i_lhs.x], weights_buffer[m, i_rhs.p, i_rhs.x])
+        
+        output_sbuf[...] = nl.copy(res_psum[0], dtype=out_tensor.dtype)
+        nl.store(out_tensor[b, (2*i+i_interleave_grp)*pmax+i_res.p, i_res.x],
+                value=output_sbuf,
+                mask=((2*i+i_interleave_grp)*pmax+i_res.p<seqlen) & (i_res.x<head_dim))
+  return out_tensor
\ No newline at end of file
diff --git a/src/reference/attention.py b/src/reference/attention.py
index 81704b5..9bf0444 100644
--- a/src/reference/attention.py
+++ b/src/reference/attention.py
@@ -6,12 +6,23 @@
 """
 import numpy as np
 
-from neuronxcc.nki import trace
 import neuronxcc.nki.isa as nisa
 import neuronxcc.nki.language as nl
+from neuronxcc import nki
 
 from neuronxcc.nki.language import par_dim
 from dataclasses import dataclass
+from functools import reduce as functools_reduce
+from operator import mul as operator_mul
+
+def n_elts(shape):
+  return functools_reduce(operator_mul, shape, 1)
+
+
+def linearize(shape, indices):
+  return sum(i * (n_elts(shape[dim + 1:]))
+             for dim, i in enumerate(indices))
+
 
 def div_ceil(n, d):
   return (n + d - 1) // d
@@ -31,9 +42,8 @@ class FlashConfig:
     'should_transpose_v': bool
   }
 
-@trace
 def _flash_attention_core(q_local_tile, k, v,
-                          q_h_per_k_h,
+                          q_h_per_k_h, seqlen_q, nheads,
                           o_buffer, l_buffer, m_buffer,
                           batch_id, head_id, gqa_head_idx, q_tile_idx,
                           local_k_large_tile_idx,
@@ -48,7 +58,7 @@ def _flash_attention_core(q_local_tile, k, v,
   """
   The flash attention core function to calcualte self attention between a tile of q and a block of K and V.
   The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF already. The block size of K and V
-  is defined in the seq_tile_size of the flash_config. The results are stored in the following there buffers
+  is defined in the seq_tile_size of the flash_config. The results are stored in the following three buffers
   o_buffer: (num_large_k_tile, B_P_SIZE, d)
   l_buffer: (num_large_k_tile, B_P_SIZE, 1)
   m_buffer: (num_large_k_tile, B_P_SIZE, 1)
@@ -56,8 +66,9 @@ def _flash_attention_core(q_local_tile, k, v,
   LARGE_TILE_SZ = flash_config.seq_tile_size
   REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
   num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
-  seq_len = k.shape[-1]
-  seq_q_num_tiles = seq_len // B_P_SIZE
+  seqlen_k = k.shape[-1]
+  seq_q_num_tiles = seqlen_q // B_P_SIZE
+  seq_k_num_tiles = seqlen_k // B_F_SIZE
 
   # Indices used by the distributed attention
   if global_k_large_tile_idx is None:
@@ -100,7 +111,7 @@ def _flash_attention_core(q_local_tile, k, v,
       q_pos = q_tile_idx * B_P_SIZE + i_q_p
       k_pos = global_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + i_q_f
       pred = q_pos >= k_pos
-      # For tiles on and on the right of the diagonal, need to do affine_select.
+      # For tiles on and to the right of the diagonal, need to do affine_select.
       # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
       qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = nisa.affine_select(
         pred=pred,
@@ -154,9 +165,9 @@ def _flash_attention_core(q_local_tile, k, v,
       for k_d_i in nl.sequential_range(REDUCTION_TILE // B_F_SIZE):
         offset = k_d_i + k_r_i * (REDUCTION_TILE // B_F_SIZE) \
                   + global_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \
-                  + q_tile_idx * (seq_len // B_F_SIZE) \
-                  + (head_id * q_h_per_k_h + gqa_head_idx) * (seq_len // B_F_SIZE) * seq_q_num_tiles \
-                  + batch_id * nl.num_programs(1) * (seq_len // B_F_SIZE) * seq_q_num_tiles
+                  + q_tile_idx * seq_k_num_tiles \
+                  + (head_id * q_h_per_k_h + gqa_head_idx) * seq_k_num_tiles * seq_q_num_tiles \
+                  + batch_id * nheads * seq_k_num_tiles * seq_q_num_tiles
         offset_seed = nl.add(seed_tensor[0, 0], offset, mask=forward_mask)
         nl.random_seed(seed=offset_seed, mask=forward_mask)
         softmax_dropout = nl.dropout(p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f],
@@ -172,9 +183,9 @@ def _flash_attention_core(q_local_tile, k, v,
   for i_p_t in nl.affine_range(LARGE_TILE_SZ // 512):
     p_local_t_tmp = nl.ndarray((par_dim(B_P_SIZE), 512), buffer=nl.psum, dtype=np.float32)
     for i_p_t_local in nl.affine_range(512//128):
-      p_local_t_tmp[i_q_p, i_p_t_local*128 + i_f_128] = nisa.nc_transpose(p_local[i_q_p, i_p_t*512+i_p_t_local * B_P_SIZE + i_f_128])
+      p_local_t_tmp[i_q_p, i_p_t_local*128 + i_f_128] = nisa.nc_transpose(p_local[i_q_p, i_p_t*512+i_p_t_local * B_P_SIZE + i_f_128], mask=forward_mask)
     i_f_512 = nl.arange(512)[None, :]
-    p_local_transposed[i_q_p, i_p_t * 512 + i_f_512 ] = nl.copy(p_local_t_tmp[i_q_p, i_f_512], dtype=kernel_dtype)
+    p_local_transposed[i_q_p, i_p_t * 512 + i_f_512 ] = nl.copy(p_local_t_tmp[i_q_p, i_f_512], dtype=kernel_dtype, mask=forward_mask)
 
   ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)
   pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32, buffer=nl.psum)
@@ -199,7 +210,8 @@ def _flash_attention_core(q_local_tile, k, v,
       l_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(l_buffer[olm_buffer_idx-1, i_q_p, 0], mask=negation_mask)
 
 
-def flash_fwd(q, k, v, seed, o, lse=None,
+@nki.jit
+def flash_fwd(q, k, v, seed,
               softmax_scale=None,
               use_causal_mask=True,
               mixed_precision=True,
@@ -213,8 +225,8 @@ def flash_fwd(q, k, v, seed, o, lse=None,
     - v: shape   (bs, nv_heads, d, seq_v) if config.should_transpose_v  else (bs, nv_heads, seq_v, d)
     - seed: shape (1,)
     - o: shape (bs, n_heads, seq_q, d)
-    - lse: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None
-    - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k
+    - lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None
+    - This kernel requires seq_k == seq_v
 
   IO tensor dtypes:
     - This kernel assumes all IO tensors have the same dtype
@@ -246,36 +258,48 @@ def flash_fwd(q, k, v, seed, o, lse=None,
   config = config or FlashConfig()
   B_F_SIZE=512
   B_P_SIZE=128
-  b , h, d, n  = q.shape
+  b, h, d, seqlen_q  = q.shape
   B_D_SIZE = d
-  k_h = k.shape[1]
-  v_shape = v.shape
+  _, k_h, _, seqlen_k = k.shape
   if config.should_transpose_v:
-    assert tuple(v_shape) == (b, k_h, d, n), f"V shape does not match layout requirements, expect: {(b, k_h, d, n)} but got {v_shape}"
-    assert tuple(k.shape) == (b, k_h, d, n), f" k and v shape does not match the layout defined in the function, but got {k.shape}"
+    assert tuple(v.shape) == (b, k_h, d, seqlen_k), f"Expect shape of V to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {v.shape}"
+    assert tuple(k.shape) == (b, k_h, d, seqlen_k), f"Expect shape of K to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {k.shape}"
   else:
-    assert tuple(v_shape) == (b, k_h, n, d), f"V shape does not match layout requirements, expect: {(b, k_h, n, d)} but got {v_shape}"
-    assert tuple(k.shape) == (b,k_h, d, n), f" k and v shape does not match the layout defined in the function, but got {k.shape}"
+    assert tuple(v.shape) == (b, k_h, seqlen_k, d), f"Expect shape of V to be {(b, k_h, seqlen_k, d)} (batch, heads, seqlen_k, d_head) but got {v.shape}"
+    assert tuple(k.shape) == (b, k_h, d, seqlen_k), f"Expect shape of K to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {k.shape}"
   assert d <= 128, f" we do not support head_dim > 128, got head dim {d}"
   kernel_dtype = nl.bfloat16 if mixed_precision else q.dtype
-  acc_type =  np.dtype(np.float32) if mixed_precision else kernel_dtype
+  acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
+
+  o = nl.ndarray((b, h, seqlen_q, d), dtype=q.dtype, buffer=nl.shared_hbm)
+  if config.training:
+    lse = nl.ndarray((b, h, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax),
+                     dtype=acc_type, buffer=nl.shared_hbm)
+  else:
+    lse = None
 
   i_q_p = nl.arange(B_P_SIZE)[:,None]
   i_0_f = nl.arange(1)[None, :]
-  n_tile_q = n//B_P_SIZE # since q will be loaded on PE
 
   batch_id = nl.program_id(axis=0)
-  head_id = nl.program_id(axis=1)
+
+  head_dims = list(range(1, nl.program_ndim()))
+  head_dims_shape = list(nl.num_programs(i) for i in head_dims)
+  head_dims_idx = list(nl.program_id(i) for i in head_dims)
+  head_id = linearize(head_dims_shape, head_dims_idx)
+
   softmax_scale = softmax_scale or (1.0 / (d ** 0.5))
 
+  n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
+
   LARGE_TILE_SZ = config.seq_tile_size
   # FIXME: Add masking for different seqlen values.
   assert config.seq_tile_size >= 512, f" seq tile_size {config.seq_tile_size} cannot be less than 512"
-  assert n % LARGE_TILE_SZ == 0, f"seqlen is not divisible by {LARGE_TILE_SZ}"
-  num_large_k_tile = n // LARGE_TILE_SZ
+  assert seqlen_k % LARGE_TILE_SZ == 0, f"Need seqlen_k to be divisible by {LARGE_TILE_SZ} but got {seqlen_k}"
+  num_large_k_tile = seqlen_k // LARGE_TILE_SZ
 
   # inference flag, check if lse is none
-  inference = not(config.training)
+  inference = not config.training
   if inference:
     assert lse is None, "lse should be none for inference"
     assert seed is None, f"seed should be None for inference, but got {seed}"
@@ -331,11 +355,13 @@ def flash_fwd(q, k, v, seed, o, lse=None,
       i_f_d = nl.arange(B_D_SIZE)[None, :]
       i_p_d = nl.arange(B_D_SIZE)[:,None]
       q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype)
-      q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, head_id * q_h_per_k_h + i_q_h, i_p_d, i*B_P_SIZE+i_f_128], dtype=kernel_dtype) \
-                                * softmax_scale # load (d, 128) tile in SBUF
+      q_tile[i_p_d, i_f_128] = nl.load(q[batch_id,
+                                        head_id * q_h_per_k_h + i_q_h, i_p_d,
+                                        i * B_P_SIZE + i_f_128],
+                                      dtype=kernel_dtype) * softmax_scale # load (d, 128) tile in SBUF
       # handle first tile and compute max and lse explicitly by passing initialize=True
       _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile,
-                            q_h_per_k_h=q_h_per_k_h,
+                            q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h,
                             o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i],
                             batch_id=batch_id, head_id=head_id,
                             gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=0,
@@ -376,10 +402,12 @@ def flash_fwd(q, k, v, seed, o, lse=None,
         i_f_d = nl.arange(B_D_SIZE)[None, :]
         i_p_d = nl.arange(B_D_SIZE)[:,None]
         q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype)
-        q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, head_id * q_h_per_k_h + i_q_h, i_p_d, i*B_P_SIZE+i_f_128], dtype=kernel_dtype) \
-                                  * softmax_scale # load (d, 128) tile in SBUF
+        q_tile[i_p_d, i_f_128] = nl.load(q[batch_id,
+                head_id * q_h_per_k_h + i_q_h, i_p_d,
+                i * B_P_SIZE + i_f_128],
+              dtype=kernel_dtype) * softmax_scale # load (d, 128) tile in SBUF
         _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile,
-                              q_h_per_k_h=q_h_per_k_h,
+                              q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h,
                               o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i],
                               batch_id=batch_id, head_id=head_id,
                               gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j,
@@ -396,19 +424,23 @@ def flash_fwd(q, k, v, seed, o, lse=None,
                                       nl.exp(m_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f] - l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f]),
                                       dtype=kernel_dtype)
 
-      nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, i * B_P_SIZE + i_q_p, i_f_d], out[i_q_p, i_f_d])
+      nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, i*B_P_SIZE + i_q_p, i_f_d], out[i_q_p, i_f_d])
       if not inference:
         lse_local = nl.zeros((par_dim(B_P_SIZE), 1), dtype=acc_type)
         lse_local[i_q_p, i_0_f] = nl.copy(l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f], dtype=acc_type)
         nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, i_q_p, i + i_0_f], lse_local[i_q_p, i_0_f])
 
+  if config.training:
+    return o, lse
+
+  return o
 
+@nki.jit
 def flash_attn_bwd(
   q_ref, k_ref, v_ref, o_ref,
   dy_ref,
   lse_ref,
   seed_ref,
-  out_dq_ref, out_dk_ref, out_dv_ref,
   use_causal_mask=False,
   mixed_precision=False,
   dropout_p=0.0,
@@ -454,56 +486,60 @@ def flash_attn_bwd(
   kernel_dtype = q_ref.dtype
   mixed_dtype = np.dtype(np.float32) if mixed_precision else kernel_dtype
 
-  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype \
-         == out_dq_ref.dtype == out_dk_ref.dtype == out_dv_ref.dtype
+  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype
   assert lse_ref.dtype == mixed_dtype
 
   # Shape checking
-  bs, nheads, d_head, seqlen = q_ref.shape
-  assert tuple(k_ref.shape) == (bs, nheads, d_head, seqlen), \
+  bs, nheads, d_head, seqlen_q = q_ref.shape
+  _, _, _, seqlen_k = k_ref.shape
+  assert tuple(k_ref.shape) == (bs, nheads, d_head, seqlen_k), \
     f"Input K shape mismatch, got {k_ref.shape}"
-  assert tuple(v_ref.shape) == (bs, nheads, d_head, seqlen), \
+  assert tuple(v_ref.shape) == (bs, nheads, d_head, seqlen_k), \
     f"Input V shape mismatch, got {v_ref.shape}"
-  assert tuple(o_ref.shape) == (bs, nheads, d_head, seqlen), \
+  assert tuple(o_ref.shape) == (bs, nheads, d_head, seqlen_q), \
     f"Input o shape mismatch, got {o_ref.shape}"
-  assert tuple(dy_ref.shape) == (bs, nheads, d_head, seqlen), \
+  assert tuple(dy_ref.shape) == (bs, nheads, d_head, seqlen_q), \
     f"Input dy shape mismatch, got {dy_ref.shape}"
-  assert tuple(lse_ref.shape) == (bs, nheads, nl.tile_size.pmax, seqlen // nl.tile_size.pmax), \
+  assert tuple(lse_ref.shape) == (bs, nheads, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax), \
     f"Input lse shape mismatch, got {lse_ref.shape}"
   if seed_ref is not None:
     assert tuple(seed_ref.shape) == (1,), \
       f"Input seed shape mismatch, got {seed_ref.shape}"
 
-  assert tuple(out_dq_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Output dQ shape mismatch, got {out_dq_ref.shape}"
-  assert tuple(out_dk_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Output dK shape mismatch, got {out_dk_ref.shape}"
-  assert tuple(out_dv_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Output dV shape mismatch, got {out_dv_ref.shape}"
+  out_dq_ref = nl.ndarray((bs, nheads, d_head, seqlen_q), dtype=q_ref.dtype,
+                          buffer=nl.shared_hbm)
+  out_dk_ref = nl.ndarray((bs, nheads, d_head, seqlen_k), dtype=q_ref.dtype,
+                          buffer=nl.shared_hbm)
+  out_dv_ref = nl.ndarray((bs, nheads, d_head, seqlen_k), dtype=q_ref.dtype,
+                          buffer=nl.shared_hbm)
 
   # FIXME: Add masking for different seqlen values.
-  assert seqlen % 128 == 0, \
-    f"Input sequence length must be divisible by 128, got {seqlen}"
+  assert seqlen_q % 128 == 0 and seqlen_k % 128 == 0, \
+    f"Input sequence lengths must be divisible by 128, got seqlen_q == {seqlen_q} and seqlen_k == {seqlen_k}"
 
   # Softmax scaling factor, multiplied onto Q
   softmax_scale = softmax_scale or 1.0 / float(d_head ** 0.5)
 
   # Different batch samples/attention heads have independent attention
   batch_id = nl.program_id(axis=0)
-  head_id = nl.program_id(axis=1)
 
-  assert nl.num_programs(1) == nheads, \
-    f"The grid shape mismatch, got {nl.num_programs(1)} but should be {nheads}"
+  head_dims = list(range(1, nl.program_ndim()))
+  head_dims_shape = list(nl.num_programs(i) for i in head_dims)
+  head_dims_idx = list(nl.program_id(i) for i in head_dims)
+  head_id = linearize(head_dims_shape, head_dims_idx)
 
-  q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen, 128), 128
+  assert n_elts(head_dims_shape) == nheads, \
+    f"The grid shape mismatch, got {n_elts(head_dims_shape)} but should be {nheads}"
+
+  q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128
   d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128)
 
-  if seqlen >= 512:
-    k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512
+  if seqlen_k >= 512:
+    k_seq_n_tiles, k_seq_tile_size = seqlen_k // 512, 512
   else:
-    k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128
+    k_seq_n_tiles, k_seq_tile_size = seqlen_k // 128, 128
 
-  k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen // 128, 128
+  k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128
   k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward
 
   ##############################################################
@@ -615,7 +651,7 @@ def flash_attn_bwd(
         dk_psum=dk_psum, dv_psum=dv_psum, dq_local_reduced=dq_local_reduced,
         softmax_exp_bias=softmax_exp_bias, dy_o_sum=dy_o_sum,
         local_i_q_seq_tile=i_q_seq_tile, local_i_k_seq_tile=i_k_seq_tile,
-        seqlen=seqlen, d_head=d_head,
+        seqlen_q=seqlen_q, seqlen_k=seqlen_k, d_head=d_head, nheads=nheads,
         use_causal_mask=use_causal_mask,
         kernel_dtype=kernel_dtype, mixed_dtype=mixed_dtype,
         softmax_scale=softmax_scale,
@@ -654,36 +690,43 @@ def flash_attn_bwd(
         value=dq_local_reduced[i_q_seq_tile, i_d_head_tile, ip_dq, if_dq],
       )
 
-@trace
+  return out_dq_ref, out_dk_ref, out_dv_ref
+
+
 def _flash_attn_bwd_core(
   q_local, k_local, transposed_k_local, v_local, dy_local,
   dk_psum, dv_psum, dq_local_reduced,
   softmax_exp_bias, dy_o_sum,
   local_i_q_seq_tile, local_i_k_seq_tile,
-  seqlen, d_head,
+  seqlen_q, seqlen_k, d_head, nheads,
   use_causal_mask,
   kernel_dtype, mixed_dtype,
   softmax_scale,
   seed_local, dropout_p, dropout_p_local,
   global_i_q_seq_tile = None,
   global_i_k_seq_tile = None,
+  # Used for nl.loop_reduce on dQ if local_i_k_seq_tile is not an index e.g. if it has an offset
+  local_i_k_seq_tile_for_dq_reduce = None,
 ):
   """
-  The flash backward core funciton to calculate the gradients of Q, K and V
+  The flash backward core function to calculate the gradients of Q, K and V
   of the given tiles. The result will be accumulated into the dk, dv, dq psum
   """
-  q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen, 128), 128
+  q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128
   d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128)
-  if seqlen >= 512:
-    k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512
+  if seqlen_k >= 512:
+    k_seq_n_tiles, k_seq_tile_size = seqlen_k // 512, 512
   else:
-    k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128
-  k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen // 128, 128
+    k_seq_n_tiles, k_seq_tile_size = seqlen_k // 128, 128
+  k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128
   k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward
 
   if global_i_q_seq_tile is None:
     global_i_q_seq_tile = local_i_q_seq_tile
     global_i_k_seq_tile = local_i_k_seq_tile
+  
+  if local_i_k_seq_tile_for_dq_reduce is None:
+    local_i_k_seq_tile_for_dq_reduce = local_i_k_seq_tile
 
   mask = global_i_q_seq_tile * q_seq_tile_size >= global_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None
   # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F]
@@ -735,7 +778,7 @@ def _flash_attn_bwd_core(
   if dropout_p > 0.0:
     offset = global_i_k_seq_tile + global_i_q_seq_tile * k_seq_n_tiles \
               + head_id * k_seq_n_tiles * q_seq_n_tiles \
-              + batch_id * nl.num_programs(1) * k_seq_n_tiles * q_seq_n_tiles
+              + batch_id * nheads * k_seq_n_tiles * q_seq_n_tiles
     offset_seed = nl.add(seed_local[0, 0], offset, mask=mask)
     nl.random_seed(seed=offset_seed, mask=mask)
     softmax_y[ip_q, if_k] = nl.dropout(softmax_y[ip_q, if_k], rate=dropout_p_local[ip_q, 0], mask=mask)
@@ -778,12 +821,12 @@ def _flash_attn_bwd_core(
   #####################################################################
   softmax_dx_local = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
   softmax_dx_local[ip_q, if_k] = \
-    nisa.tensor_scalar(data=softmax_dy[ip_q, if_k],
-                        op0=np.subtract,
-                        operand0=dy_o_sum[local_i_q_seq_tile, ip_q, 0],
-                        op1=np.multiply,
-                        operand1=softmax_y[ip_q, if_k],
-                        mask=mask)
+    nisa.scalar_tensor_tensor(data=softmax_dy[ip_q, if_k],
+                              op0=np.subtract,
+                              operand0=dy_o_sum[local_i_q_seq_tile, ip_q, 0],
+                              op1=np.multiply,
+                              operand1=softmax_y[ip_q, if_k],
+                              mask=mask)
 
   #####################################################################
   # Step 5.1 Calculate dK, with matmul(stationary=Q, moving=softmax_dx)
@@ -820,10 +863,12 @@ def _flash_attn_bwd_core(
           mask=mask)
     dq_local = nl.multiply(dq_psum[ip_dq, if_dq], softmax_scale, dtype=kernel_dtype, mask=mask)
     dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, ip_dq, if_dq] = nl.loop_reduce(
-      dq_local, op=np.add, loop_indices=(local_i_k_seq_tile,),
+      dq_local, op=np.add, loop_indices=(local_i_k_seq_tile_for_dq_reduce,),
       dtype=mixed_dtype, mask=mask)
 
-def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False,
+
+@nki.jit
+def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
                                            mixed_percision=True):
   """
   Fused self attention kernel for small head size Stable Diffusion workload.
@@ -853,16 +898,17 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
   # Assume all IO tensors have the same dtype
   kernel_dtype = q_ref.dtype
   pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
-  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype
+  assert q_ref.dtype == k_ref.dtype == v_ref.dtype
 
   # Shape checking
   bs, d_head, seqlen = q_ref.shape
   assert d_head <= 128, "Cannot use this kernel for d_head > 128"
   assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!'
   assert tuple(k_ref.shape) == (bs, seqlen, d_head), 'Input shape mismatch!'
-  assert tuple(v_ref.shape) == (bs, seqlen,
-                                d_head), f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
-  assert tuple(out_ref.shape) == (bs, seqlen, d_head), 'Output shape mismatch!'
+  assert tuple(v_ref.shape) == (bs, seqlen,  d_head), \
+    f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
+
+  out_ref = nl.ndarray((bs, seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm)
 
   # Softmax scaling factor, multiplied onto Q
   softmax_scale = 0.125
@@ -1028,4 +1074,5 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
     nl.store(
       out_ref[batch_id, i_q_seq_tile * q_seq_tile_size + if_out, ip_out],
       value=attn_res_div)
-    
\ No newline at end of file
+
+  return out_ref
diff --git a/src/reference/tutorial.py b/src/reference/tutorial.py
index 4f3ebef..b32492b 100644
--- a/src/reference/tutorial.py
+++ b/src/reference/tutorial.py
@@ -5,9 +5,14 @@
 
 """
 
+from neuronxcc import nki
 import neuronxcc.nki.language as nl
 
-def add_kernel_nx8x128x512(a_ptr, b_ptr, c_ptr, n_elements):
+
+@nki.jit
+def add_kernel_nx8x128x512(a_ptr, b_ptr, n_elements):
+  c_ptr = nl.ndarray(a_ptr.shape, dtype=a_ptr.dtype, buffer=nl.shared_hbm)
+
   ix = nl.arange(128)[:, None]
   iy = nl.arange(512)[None, :]
 
@@ -18,12 +23,9 @@ def add_kernel_nx8x128x512(a_ptr, b_ptr, c_ptr, n_elements):
 
   for i in nl.affine_range(8):
     offset = j * block_size + i * tile_size + 512 * ix + iy
-    mask = offset < n_elements
-    a_ptr = a_ptr.ptr + offset
-    b_ptr = b_ptr.ptr + offset
-    c_ptr = c_ptr.ptr + offset
-
-    a = nl.load(a_ptr, mask=mask)
-    b = nl.load(b_ptr, mask=mask)
-    c = a + b
-    nl.store(c_ptr, value=c, mask=mask)
\ No newline at end of file
+    a = nl.load(a_ptr[j, i, ix, iy], mask=offset < n_elements)
+    b = nl.load(b_ptr[j, i, ix, iy], mask=offset < n_elements)
+    c = nl.add(a, b, mask=offset < n_elements)
+    nl.store(c_ptr[j, i, ix, iy], value=c, mask=offset < n_elements)
+
+  return c_ptr
diff --git a/src/reference/vision.py b/src/reference/vision.py
index bc54941..4899d27 100644
--- a/src/reference/vision.py
+++ b/src/reference/vision.py
@@ -8,10 +8,13 @@
 
 import neuronxcc.nki.language as nl
 import neuronxcc.nki.isa as nisa
+from neuronxcc import nki
 from neuronxcc.nki.language import par_dim
 import neuronxcc.nki.typing as nt
 
-def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):
+
+@nki.jit
+def select_and_scatter_kernel(operand_tensor, source_tensor):
   """
   Implementation of a select-and-scatter kernel.
 
@@ -51,7 +54,10 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):
   assert C == 64 and N % 2 == 0
 
   kernel_dtype = operand_tensor.dtype
-  assert operand_tensor.dtype == source_tensor.dtype == out_tensor.dtype
+  assert operand_tensor.dtype == source_tensor.dtype
+
+  out_tensor = nl.ndarray((N, C, H, W), dtype=operand_tensor.dtype,
+                          buffer=nl.shared_hbm)
 
   p = 128  # num of partitions to use
   for ib in nl.affine_range(N // 2):
@@ -156,8 +162,11 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):
       nl.store(out_tensor[2 * ib + ib_1, 0:64, 0:H, 0:W],
                value=out_local[(ib_1 * 64):((ib_1 + 1) * 64), 0:H, 0:W])
 
+  return out_tensor
 
-def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):
+
+@nki.jit
+def resize_nearest_fixed_dma_kernel(data_tensor, out_shape):
   """
   Resize the input image to the given size using the nearest interpolation mode. This kernel is designed to be used when the scaling factor is not an integer. 
 
@@ -174,7 +183,9 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):
   
   """
   in_b, in_h, in_w, in_c = data_tensor.shape
-  out_b, out_h, out_w, out_c = out_tensor.shape
+  out_b, out_h, out_w, out_c = out_shape
+  out_tensor = nl.ndarray(out_shape, dtype=data_tensor.dtype,
+                          buffer=nl.shared_hbm)
 
   assert in_b == out_b, "Input batch and output batch must be identical"
   assert in_c == out_c, "Input channel and output channel must be identical"
@@ -198,3 +209,5 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):
     local_data = nl.load(target_addr)
     dst_addr_0 = out_tile[b_map, i, c_map]
     nl.store(dst_addr_0, value=local_data)
+
+  return out_tensor
diff --git a/src/tutorials/average_pool2d/average_pool2d_jax.py b/src/tutorials/average_pool2d/average_pool2d_jax.py
index e3b428d..139c42d 100644
--- a/src/tutorials/average_pool2d/average_pool2d_jax.py
+++ b/src/tutorials/average_pool2d/average_pool2d_jax.py
@@ -4,29 +4,22 @@
 JAX implementation for average pool 2D NKI tutorial.
 
 """
-from functools import partial
-from jax_neuronx import nki_call
-import jax
+# NKI_EXAMPLE_40_BEGIN
 import jax.numpy as jnp
-
-from average_pool2d_nki_kernels import tensor_avgpool_kernel_
-
-
-def tensor_avgpool_kernel(in_array, pool_size):
-  return nki_call(
-    partial(tensor_avgpool_kernel_, pool_size=pool_size),
-    in_array,
-    out_shape=jax.ShapeDtypeStruct((C, HOUT, WOUT), dtype=in_array.dtype),
-  )
+# NKI_EXAMPLE_40_END
+from average_pool2d_nki_kernels import tensor_avgpool_kernel
 
 
+# NKI_EXAMPLE_40_BEGIN
 # Reference JAX implementation
 def jax_average_pool_2D(in_tensor, pool_size):
   c, h_in, w_in = in_tensor.shape
   reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
   return jnp.nanmean(reshaped, axis=(2, 4))
+  # NKI_EXAMPLE_40_END
 
 
+# NKI_EXAMPLE_41_BEGIN
 if __name__ == "__main__":
   POOL_SIZE = 2
   C, HIN, WIN = 2, 6, 6
@@ -34,7 +27,9 @@ def jax_average_pool_2D(in_tensor, pool_size):
 
   in_array = jnp.arange(C * HIN * WIN, dtype=jnp.float32).reshape(C, HIN, WIN)
 
+  # NKI_EXAMPLE_39_BEGIN
   out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE)
+  # NKI_EXAMPLE_39_END
   out_jax = jax_average_pool_2D(in_array, pool_size=POOL_SIZE)
 
   print(in_array, out_nki, out_jax)
@@ -42,4 +37,5 @@ def jax_average_pool_2D(in_tensor, pool_size):
   if jnp.allclose(out_nki, out_jax):
     print("NKI and JAX match")
   else:
-    print("NKI and JAX differ")
\ No newline at end of file
+    print("NKI and JAX differ")
+    # NKI_EXAMPLE_41_END
diff --git a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py b/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py
index c81a4a5..68d3a31 100644
--- a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py
+++ b/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py
@@ -5,48 +5,40 @@
 
 """
 import numpy as np
+# NKI_EXAMPLE_37_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
+from neuronxcc.nki.typing import tensor
 
-
-def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
+@nki.jit
+def tensor_avgpool_kernel(in_tensor, pool_size):
   """NKI kernel to compute a 2D avg-pool operation
 
   Args:
       in_tensor: an input tensor, of shape C x H x W
       pool_size: an integer representing a (square) pool-window size
+
+  Return:
       out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
   """
 
   # Get input/output dimensions
   sz_cin, sz_hin, sz_win = in_tensor.shape
-  sz_cout, sz_hout, sz_wout = out_tensor.shape
-  assert sz_cin == sz_cout
+  sz_hout = sz_hin // pool_size
+  sz_wout = sz_win // pool_size
+  # Create output tensor shared between all SPMD instances as result tensor
+  out_tensor = nl.ndarray((sz_cin, sz_hout, sz_wout), dtype=in_tensor.dtype,
+                          buffer=nl.shared_hbm)
 
   # Set relevant sizes
   sz_p = sz_cin
   sz_pool = pool_size
 
-  # Generate tensor h/w index patterns
-  # 3D indexing according to [C, H, W]
-  i_p = nl.arange(sz_p)[:, None, None] # 3D for
-  i_win = nl.arange(sz_win)[None, None, :]
-  i_hin = nl.arange(sz_hin)[None, :, None]
-
-  i_wout = nl.arange(sz_wout)[None, None, :]
-  i_hout = nl.arange(sz_hout)[None, :, None]
-
   # Generate pool index patterns (requires two extra dimensions, for the pool window)
-  i_0 = nl.arange(sz_p)[:, None, None, None, None] #
-  i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer
-  i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner
-  i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer
-  i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner
+  i0, i1, i2, i3, i4 = nl.mgrid[0:sz_p, 0:sz_hin//sz_pool, 0:sz_pool, 0:sz_win//sz_pool, 0:sz_pool]
 
   # Load input data from external memory to on-chip memory
-  # Declare ndarray to force a 3D tensor (temporary requirement)
-  in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
-  in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])
+  in_tile: tensor[sz_p, sz_hin, sz_win] = nl.load(in_tensor)
 
   # Perform the pooling operation:
   # We use numpy's advanced indexing, in order to extend in_tile to 5D, and then reduce-average two dimension.
@@ -54,10 +46,15 @@ def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
   # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
   # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
   # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
-  out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size)
+  out_tile : tensor[sz_p, sz_hout, sz_wout] = nl.sum(in_tile[i0, sz_pool*i1+i2, sz_pool*i3+i4],
+                                                     axis=[2,4]) / (pool_size*pool_size)
+
+  # Store the results back to hbm
+  nl.store(out_tensor, value=out_tile)
 
-  # Store the results back to external memory
-  nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)
+  # Transfer the ownership of `out_tensor` to the caller
+  return out_tensor
+  # NKI_EXAMPLE_37_END
 
 
 # Reference NumPy implementation
@@ -74,10 +71,8 @@ def np_average_pool_2D(in_tensor, pool_size):
   HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE
 
   in_tensor = np.arange(C * HIN * WIN, dtype=np.float16).reshape(C, HIN, WIN)
-  out_nki = np.zeros((C, HOUT, WOUT), dtype=np.float16)
 
-  tensor_avgpool_kernel_baremetal = nki.baremetal(tensor_avgpool_kernel_)
-  tensor_avgpool_kernel_baremetal(in_tensor, out_nki, POOL_SIZE)
+  out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE)
 
   out_np = np_average_pool_2D(in_tensor, POOL_SIZE)
 
diff --git a/src/tutorials/average_pool2d/average_pool2d_torch.py b/src/tutorials/average_pool2d/average_pool2d_torch.py
index 3409a31..c5fb4ea 100644
--- a/src/tutorials/average_pool2d/average_pool2d_torch.py
+++ b/src/tutorials/average_pool2d/average_pool2d_torch.py
@@ -4,13 +4,14 @@
 PyTorch implementation for average pool 2D NKI tutorial.
 
 """
+# NKI_EXAMPLE_38_BEGIN
 import torch
-from torch_neuronx import nki_jit
 from torch_xla.core import xla_model as xm
-
-from average_pool2d_nki_kernels import tensor_avgpool_kernel_
+# NKI_EXAMPLE_38_END
+from average_pool2d_nki_kernels import tensor_avgpool_kernel
 
 
+# NKI_EXAMPLE_38_BEGIN
 if __name__ == "__main__":
   device = xm.xla_device()
 
@@ -22,8 +23,7 @@
   in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device)
   out_nki = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device)
 
-  tensor_avgpool_kernel_torch = nki_jit(tensor_avgpool_kernel_)
-  tensor_avgpool_kernel_torch(in_tensor, out_nki, POOL_SIZE)
+  out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE)
 
   out_torch = torch.nn.functional.avg_pool2d(in_tensor, POOL_SIZE, POOL_SIZE)
 
@@ -33,3 +33,4 @@
     print("NKI and Torch match")
   else:
     print("NKI and Torch differ")
+    # NKI_EXAMPLE_38_END
diff --git a/src/tutorials/fused_mamba/mamba_nki_kernels.py b/src/tutorials/fused_mamba/mamba_nki_kernels.py
index 9f8af60..4ff6642 100644
--- a/src/tutorials/fused_mamba/mamba_nki_kernels.py
+++ b/src/tutorials/fused_mamba/mamba_nki_kernels.py
@@ -4,16 +4,19 @@
 Mamba-v1 NKI kernel implementation.
 
 """
+# NKI_EXAMPLE_25_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 import neuronxcc.nki.isa as nisa
 import numpy as np
+# NKI_EXAMPLE_25_END
 import os
 import argparse
 import itertools
 
-
-def mamba_v1(delta, u, A, B, C, output):
+# NKI_EXAMPLE_25_BEGIN
+@nki.jit
+def mamba_v1(delta, u, A, B, C):
     """Computes the SSM operation in the Mamba model.
 
     :param delta: (batch_size, channels, seq_len)
@@ -24,6 +27,9 @@ def mamba_v1(delta, u, A, B, C, output):
     :return: (batch_size, channels, seq_len)
     """
     batch_size, channels, seq_len = delta.shape
+    output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype,
+                        buffer=nl.shared_hbm)
+
     _, state_size = A.shape
 
     # We can relax this using mask paramters in all the NKI API calls
@@ -84,8 +90,12 @@ def mamba_v1(delta, u, A, B, C, output):
             nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                     scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len])
 
+    return output
+# NKI_EXAMPLE_25_END
 
-def mamba_v2(delta, u, A, B, C, output):
+# NKI_EXAMPLE_26_BEGIN
+@nki.jit
+def mamba_v2(delta, u, A, B, C):
     """Computes the SSM operation in the Mamba model.
 
     :param delta: (batch_size, channels, seq_len)
@@ -96,6 +106,8 @@ def mamba_v2(delta, u, A, B, C, output):
     :return: (batch_size, channels, seq_len)
     """
     batch_size, channels, seq_len = delta.shape
+    output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype,
+                        buffer=nl.shared_hbm)
     _, state_size = A.shape
 
     assert channels % 128 == 0
@@ -153,8 +165,12 @@ def mamba_v2(delta, u, A, B, C, output):
             nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                     scanC_accum[0:channel_psize, 0:seq_len])
 
+    return output
+# NKI_EXAMPLE_26_END
+
 
-def mamba_v3(delta, u, A, B, C, output):
+@nki.jit
+def mamba_v3(delta, u, A, B, C):
     """Computes the SSM operation in the Mamba model.
 
     :param delta: (batch_size, channels, seq_len)
@@ -165,6 +181,8 @@ def mamba_v3(delta, u, A, B, C, output):
     :return: (batch_size, channels, seq_len)
     """
     batch_size, channels, seq_len = delta.shape
+    output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype,
+                        buffer=nl.shared_hbm)
     _, state_size = A.shape
 
     # Map channels to the partition dimension
@@ -239,6 +257,7 @@ def mamba_v3(delta, u, A, B, C, output):
             # Store scanC_accum for a single batch to output
             nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                     scanC_accum[0:channel_psize, 0:seq_len])
+    return output
 
 
 def parse_args():
@@ -310,9 +329,7 @@ def parse_args():
         if args.mode == "accuracy":
             # v1: reference kernel
             print(f">>>> Running v1 (reference).")
-            nki_out_v1 = np.empty((batch, channels, seq_len), dtype=dtype)
-            nki.baremetal(mamba_v1)\
-                         (delta, u, A, B, C, nki_out_v1)
+            nki_out_v1 = mamba_v1(delta, u, A, B, C)
 
             for version in args.version:
                 if version == "v1":
@@ -321,9 +338,7 @@ def parse_args():
 
                 print(f">>>> Running version {version}.")
                 func = func_dict[version]
-                nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype)
-                nki.baremetal(func)\
-                             (delta, u, A, B, C, nki_out_test)
+                nki_out_test = func(delta, u, A, B, C)
                 print(f">>>> mamba {version} matches?", np.all(nki_out_test == nki_out_v1))
                 assert np.all(nki_out_test == nki_out_v1)
 
@@ -333,11 +348,10 @@ def parse_args():
             for version in args.version:
                 print(f">>>> Running version {version}.")
                 func = func_dict[version]
-                nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype)
                 nki.benchmark(func,
                               save_neff_name='file.neff',
                               save_trace_name='profile.ntff')\
-                             (delta, u, A, B, C, nki_out_test)
+                             (delta, u, A, B, C)
                 # TODO: rename neff/ntff (bug in nki.benchmark with neff name)
                 os.rename("file.neff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.neff")
                 os.rename("profile.ntff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.ntff")
diff --git a/src/tutorials/fused_mamba/mamba_torch.py b/src/tutorials/fused_mamba/mamba_torch.py
index a2e593f..cd94a0b 100644
--- a/src/tutorials/fused_mamba/mamba_torch.py
+++ b/src/tutorials/fused_mamba/mamba_torch.py
@@ -5,6 +5,7 @@
 
 """
 
+# NKI_EXAMPLE_24_BEGIN
 import torch
 import torch_neuronx
 import torch_xla.core.xla_model as xm
@@ -99,16 +100,14 @@ def parse_args():
     torch_out = mamba_layer(delta, A, B, u, C)
     xm.mark_step()
     print(torch_out)
+    # NKI_EXAMPLE_24_END
 
     if args.mode == "accuracy":
         # Call NKI mamba_v1 kernel to check accuracy
         from mamba_nki_kernels import mamba_v1
-        from torch_neuronx import nki_jit
-
-        nki_out = torch.empty((batch, channels, seq_len), dtype=dtype, device=device)
 
         xm.mark_step()
-        nki_jit(mamba_v1)(delta, u, A, B, C, nki_out)
+        nki_out = mamba_v1(delta, u, A, B, C)
         xm.mark_step()
 
         allclose = torch.allclose(torch_out, nki_out, atol=1e-2, rtol=1e-2)
diff --git a/src/tutorials/layernorm/layernorm_nki_kernel.py b/src/tutorials/layernorm/layernorm_nki_kernel.py
index 503ce7d..c0c235c 100644
--- a/src/tutorials/layernorm/layernorm_nki_kernel.py
+++ b/src/tutorials/layernorm/layernorm_nki_kernel.py
@@ -4,21 +4,27 @@
 LayerNorm NKI kernel implementation.
 
 """
+# NKI_EXAMPLE_45_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 import neuronxcc.nki.isa as nisa
 import numpy as np
 import math
+# NKI_EXAMPLE_45_END
 import os
 import argparse
 
 
-def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor):
+# NKI_EXAMPLE_45_BEGIN
+@nki.jit
+def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector):
   """Computes LayerNorm.
     Used nki.language APIs only.
   """
+  output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype,
+                             buffer=nl.shared_hbm)
+
   # Ensure that the shapes of tensors match
-  assert input_tensor.shape == output_tensor.shape
   assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0]
 
   # Generate tile indices for loading/storing data
@@ -58,12 +64,20 @@ def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, ou
     nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb,
              mask=(i * nl.tile_size.pmax + i_p_io < num_rows))
 
-def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor):
+  return output_tensor
+  # NKI_EXAMPLE_45_END
+
+
+# NKI_EXAMPLE_46_BEGIN
+@nki.jit
+def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector):
   """Computes LayerNorm.
     Used nki.isa APIs to calculate mean/variance and perform shift/scale.
   """
+  output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype,
+                             buffer=nl.shared_hbm)
+
   # Ensure that the shapes of tensors match
-  assert input_tensor.shape == output_tensor.shape
   assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0]
 
   # Generate tile indices for loading/storing data
@@ -122,69 +136,66 @@ def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, ou
     nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb,
              mask=(i * nl.tile_size.pmax + i_p_io < num_rows))
 
+  return output_tensor
+  # NKI_EXAMPLE_46_END
+
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-    """Run LayerNorm pytorch implementation.
-    """)
-    parser.add_argument("--nrows",
-                        default=4*1024,
-                        type=int,
-                        help="""The number of input rows""")
-    parser.add_argument("--ncols",
-                        default=8*1024,
-                        type=int,
-                        help="""The number of input columns""")
-    parser.add_argument("--mode",
-                        choices=["accuracy", "perf"],
-                        default="accuracy",
-                        help="""Do accuracy test or perf test.
-                                Accuracy test compares LayerNorm kernel against PyTorch implementation.
-                                Perf test will generate a NEFF for the PyTorch implementation in local directory
-                                for a manual run of neuron-profile.
-                             """)
-    args = parser.parse_args()
-    return args
+  parser = argparse.ArgumentParser(
+  """Run LayerNorm pytorch implementation.
+  """)
+  parser.add_argument("--nrows",
+                      default=4*1024,
+                      type=int,
+                      help="""The number of input rows""")
+  parser.add_argument("--ncols",
+                      default=8*1024,
+                      type=int,
+                      help="""The number of input columns""")
+  parser.add_argument("--mode",
+                      choices=["accuracy", "perf"],
+                      default="accuracy",
+                      help="""Do accuracy test or perf test.
+                              Accuracy test compares LayerNorm kernel against PyTorch implementation.
+                              Perf test will generate a NEFF for the PyTorch implementation in local directory
+                              for a manual run of neuron-profile.
+                           """)
+  args = parser.parse_args()
+  return args
 
 if __name__ == "__main__":
-    args = parse_args()
-    func_dict = {"v1": nki_layernorm_kernel_v1,
-                 "v2": nki_layernorm_kernel_v2,
-                 }
-
-    # Generate toy example
-    num_rows = args.nrows
-    num_cols = args.ncols
-    input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32)
-    gamma_vector = np.random.rand(num_cols).astype(np.float32)
-    beta_vector = np.random.rand(num_cols).astype(np.float32)
-    epsilon = 1e-5
-            
-    if args.mode == "accuracy":
-        # version 1
-        print(f">>>> Running version 1")
-        nki_out_v1 = np.empty((num_rows, num_cols), dtype=np.float32)
-        nki.baremetal(nki_layernorm_kernel_v1)\
-                    (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v1)
-        # version 2
-        print(f">>>> Running version 2")
-        nki_out_v2 = np.empty((num_rows, num_cols), dtype=np.float32)
-        nki.baremetal(nki_layernorm_kernel_v2)\
-                    (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v2)
-        # compare
-        np_all = np.all(nki_out_v1 == nki_out_v1)
-        print(f">>>> LayerNorm V1 and V2 matches?", np_all)
-        assert np_all
-                
-    else:
-      # perf mode
-      for version in ["v1", "v2"]:
-          print(f">>>> Running version {version}.")
-          func = func_dict[version]
-          nki_out_test = np.empty((num_rows, num_cols), dtype=np.float32)
-          nki.benchmark(func,
-                        save_neff_name='file.neff',
-                        save_trace_name='profile.ntff')\
-                        (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_test)
-          os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff")
-          os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff")
+  args = parse_args()
+  func_dict = {"v1": nki_layernorm_kernel_v1,
+               "v2": nki_layernorm_kernel_v2,
+               }
+
+  # Generate toy example
+  num_rows = args.nrows
+  num_cols = args.ncols
+  input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32)
+  gamma_vector = np.random.rand(num_cols).astype(np.float32)
+  beta_vector = np.random.rand(num_cols).astype(np.float32)
+  epsilon = 1e-5
+
+  if args.mode == "accuracy":
+    # version 1
+    print(f">>>> Running version 1")
+    nki_out_v1 = nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector)
+    # version 2
+    print(f">>>> Running version 2")
+    nki_out_v2 = nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector)
+    # compare
+    np_all = np.all(nki_out_v1 == nki_out_v1)
+    print(f">>>> LayerNorm V1 and V2 matches?", np_all)
+    assert np_all
+
+  else:
+    # perf mode
+    for version in ["v1", "v2"]:
+      print(f">>>> Running version {version}.")
+      func = func_dict[version]
+      benchmark_kernel = nki.benchmark(func, save_neff_name='file.neff',
+                                       save_trace_name='profile.ntff')
+      nki_out_test = benchmark_kernel(input_tensor, epsilon, gamma_vector, beta_vector)
+      os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff")
+      os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff")
diff --git a/src/tutorials/layernorm/layernorm_torch.py b/src/tutorials/layernorm/layernorm_torch.py
index 59853fd..c2be186 100644
--- a/src/tutorials/layernorm/layernorm_torch.py
+++ b/src/tutorials/layernorm/layernorm_torch.py
@@ -4,9 +4,9 @@
 LayerNorm NKI kernel implementation.
 
 """
+# NKI_EXAMPLE_47_BEGIN
 import torch
 from torch_xla.core import xla_model as xm
-from torch_neuronx import nki_jit
 import argparse
 import os
 
@@ -42,13 +42,16 @@ def parse_args():
     args = parser.parse_args()
     return args
 
+
+from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, \
+  nki_layernorm_kernel_v2
+
 if __name__ == "__main__":
     args = parse_args()
-    from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, nki_layernorm_kernel_v2
     func_dict = {"v1": nki_layernorm_kernel_v1,
                  "v2": nki_layernorm_kernel_v2,
                  }
-    
+
     device = xm.xla_device()
     num_rows = args.nrows
     num_cols = args.ncols
@@ -58,7 +61,7 @@ def parse_args():
     gamma_vector = torch.rand((num_cols), dtype=torch.float32)
     beta_vector = torch.rand((num_cols), dtype=torch.float32)
     epsilon = 1e-5
-    
+
     # Compute torch layernorm layer in cpu
     output_torch = layernorm_layer(input_tensor, epsilon, gamma_vector, beta_vector)
 
@@ -66,17 +69,15 @@ def parse_args():
     input_tensor = input_tensor.to(device=device)
     gamma_vector = gamma_vector.to(device=device)
     beta_vector = beta_vector.to(device=device)
-    output_nki = torch.zeros((num_rows, num_cols), dtype=torch.float32).to(device=device)
 
     print(f">>>> Running version {args.version}.")
     func = func_dict[args.version]
 
     # add nki_jit decorator
-    nki_layernorm_kernel = nki_jit(func)
 
     # Compute NKI layernorm kernel in NeuronDevice
     xm.mark_step()
-    nki_layernorm_kernel(input_tensor, epsilon, gamma_vector, beta_vector, output_nki)
+    output_nki = func(input_tensor, epsilon, gamma_vector, beta_vector)
     xm.mark_step()
     output_nki = output_nki.to(device='cpu')
 
@@ -86,5 +87,6 @@ def parse_args():
         print("NKI and Torch match")
     else:
         print("NKI and Torch differ")
-    
+        # NKI_EXAMPLE_47_END
+
     assert allclose
\ No newline at end of file
diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py b/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
index 7aeb5d6..8f913f2 100644
--- a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
+++ b/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
@@ -12,7 +12,9 @@
 import numpy as np
 
 
-def nki_matmul_basic_(lhsT, rhs, result):
+# NKI_EXAMPLE_16_BEGIN
+@nki.jit
+def nki_matmul_basic_(lhsT, rhs):
   """NKI kernel to compute a 64x128x512 matrix multiplication operation
 
   Args:
@@ -20,8 +22,11 @@ def nki_matmul_basic_(lhsT, rhs, result):
         matrix multiplication, delivered transposed for optimal performance
       rhs: an input tensor of shape [128,512], a right hand side argument of the
         matrix multiplication
+  Returns:
       result: the resulting output tensor of shape [64,512]
   """
+  result = nl.ndarray((64, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm)
+
   # Defining indexes for input LHS.T
   # - Note: here we take LayoutConstraint #1 into account:
   # "For MatMult, contraction axis must be mapped to P-dim"
@@ -53,8 +58,13 @@ def nki_matmul_basic_(lhsT, rhs, result):
   # This dictates which indices to use to address the result tile.
   nl.store(result[i_out_p, i_out_f], value=result_sbuf)
 
+  return result
+  # NKI_EXAMPLE_16_END
+
 
-def nki_matmul_tiled_(lhsT, rhs, result):
+# NKI_EXAMPLE_18_BEGIN
+@nki.jit
+def nki_matmul_tiled_(lhsT, rhs):
   """NKI kernel to compute a matrix multiplication operation in a tiled manner
 
   Args:
@@ -64,12 +74,14 @@ def nki_matmul_tiled_(lhsT, rhs, result):
       rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
         is a multiple of 512.  It is the right-hand-side argument of the matrix
         multiplication.
+  Returns:
       result: the resulting output tensor of shape [M,N]
   """
 
   K, M = lhsT.shape
   K_, N = rhs.shape
   assert K == K_, "lhsT and rhs must have the same contraction dimension"
+  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)
 
   TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
   TILE_K = nl.tile_size.pmax  # 128
@@ -100,8 +112,13 @@ def nki_matmul_tiled_(lhsT, rhs, result):
       nl.store(result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N],
                value=res_sb)
 
+  return result
+  # NKI_EXAMPLE_18_END
 
-def nki_matmul_hoist_load_(lhsT, rhs, result):
+
+# NKI_EXAMPLE_19_BEGIN
+@nki.jit
+def nki_matmul_hoist_load_(lhsT, rhs):
   """NKI kernel to compute a matrix multiplication operation in a tiled manner
      while hoisting the load of the lhsT and rhs to outer loops.
 
@@ -112,12 +129,14 @@ def nki_matmul_hoist_load_(lhsT, rhs, result):
       rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
         is a multiple of 512.  It is the right-hand-side argument of the matrix
         multiplication.
+  Returns:
       result: the resulting output tensor of shape [M,N]
   """
 
   K, M = lhsT.shape
   K_, N = rhs.shape
   assert K == K_, "lhsT and rhs must have the same contraction dimension"
+  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)
 
   TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
   TILE_K = nl.tile_size.pmax  # 128
@@ -163,8 +182,13 @@ def nki_matmul_hoist_load_(lhsT, rhs, result):
       res_sb = nl.copy(res_psum, dtype=result.dtype)
       nl.store(result[m * TILE_M + i_res.p, n * TILE_N + i_res.x], value=res_sb)
 
+  return result
+  # NKI_EXAMPLE_19_END
+
 
-def nki_matmul_block_free_dimension_(lhsT, rhs, result):
+# NKI_EXAMPLE_20_BEGIN
+@nki.jit
+def nki_matmul_block_free_dimension_(lhsT, rhs):
   """NKI kernel to compute a matrix multiplication operation while blocking the
      free dimensions of the LHS and RHS to improve memory access pattern.
 
@@ -175,12 +199,14 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result):
       rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
         is a multiple of 512.  It is the right-hand-side argument of the matrix
         multiplication.
+  Returns:
       result: the resulting output tensor of shape [M,N]
   """
 
   K, M = lhsT.shape
   K_, N = rhs.shape
   assert K == K_, "lhsT and rhs must have the same contraction dimension"
+  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)
 
   TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
   TILE_K = nl.tile_size.pmax  # 128
@@ -243,11 +269,15 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result):
                           (n * TILES_IN_BLOCK_N + bn) * TILE_N + i_res.x],
                    value=res_sb)
 
+  return result
+  # NKI_EXAMPLE_20_END
 
+
+# NKI_EXAMPLE_21_BEGIN
+@nki.jit
 def nki_matmul_fully_optimized_(
     lhsT,
     rhs,
-    result,
     # Meta-parameters
     TILES_IN_BLOCK_M=16,
     TILES_IN_BLOCK_N=2,
@@ -264,13 +294,15 @@ def nki_matmul_fully_optimized_(
       rhs: an input tensor of shape [K,N],  where K is a multiple of 128 *
         TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N.  It is
         the right-hand-side argument of the matrix multiplication.
-      result: the resulting output tensor of shape [M,N]
       TILES_IN_BLOCK_*: meta parameters to control blocking dimensions
+  Returns:
+      result: the resulting output tensor of shape [M,N]
   """
 
   K, M = lhsT.shape
   K_, N = rhs.shape
   assert K == K_, "lhsT and rhs must have the same contraction dimension"
+  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)
 
   TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
   TILE_K = nl.tile_size.pmax  # 128
@@ -360,16 +392,19 @@ def nki_matmul_fully_optimized_(
                         BLOCK_N * n + i_res_packed.x],
                  value=result_packed[i_res_packed.p, i_res_packed.x])
 
+  return result
+# NKI_EXAMPLE_21_END
+
 
+# NKI_EXAMPLE_23_BEGIN
 if __name__ == "__main__":
   # Benchmarking with large matrices to show the differences more clearly
   lhsT = nt.tensor[[8192, 4096], nl.bfloat16]
   rhs = nt.tensor[[8192, 8192], nl.bfloat16]
-  output = nt.tensor[[4096, 8192], nl.bfloat16]
 
   def benchmark_nki(nki_func):
     bench_func = nki.benchmark(warmup=5, iters=10)(nki_func)
-    bench_func(lhsT, rhs, output)
+    bench_func(lhsT, rhs)
     latency_res = bench_func.benchmark_result.nc_latency
     p99 = latency_res.get_latency_percentile(99)
     print("Latency: {:.2f} ms (P99)".format(p99 / 1000.0))
@@ -385,3 +420,4 @@ def benchmark_nki(nki_func):
 
   print("Benchmarking nki_matmul_fully_optimized")
   benchmark_nki(nki_matmul_fully_optimized_)
+  # NKI_EXAMPLE_23_END
diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py b/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py
index ec0084c..de39ce8 100644
--- a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py
+++ b/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py
@@ -7,23 +7,21 @@
 
 import torch
 from torch_xla.core import xla_model as xm
-from torch_neuronx import nki_jit
 
 from matrix_multiplication_nki_kernels import nki_matmul_basic_, nki_matmul_tiled_, nki_matmul_hoist_load_, nki_matmul_block_free_dimension_, nki_matmul_fully_optimized_
 
 if __name__ == "__main__":
 
+  # NKI_EXAMPLE_17_BEGIN
   device = xm.xla_device()
   cpu = torch.device('cpu')
 
   # Test the small workload with basic kernel
   lhs_small = torch.rand((64, 128), dtype=torch.bfloat16, device=device)
   rhs_small = torch.rand((128, 512), dtype=torch.bfloat16, device=device)
-  output_small = torch.zeros((64, 512), dtype=torch.bfloat16, device=device)
 
   # Run NKI kernel
-  nki_matmul_basic_jit = nki_jit(nki_matmul_basic_)
-  nki_matmul_basic_jit(lhs_small.T, rhs_small, output_small)
+  output_small = nki_matmul_basic_(lhs_small.T, rhs_small)
 
   # Run torch reference
   output_small_torch = torch.matmul(lhs_small, rhs_small)
@@ -34,18 +32,18 @@
     print("NKI and Torch match")
   else:
     print("NKI and Torch differ")
+    # NKI_EXAMPLE_17_END
 
+  # NKI_EXAMPLE_22_BEGIN
   # Test the large workload with tiled kernels
   lhs = torch.rand((4096, 1024), dtype=torch.bfloat16, device=device)
   rhs = torch.rand((1024, 2048), dtype=torch.bfloat16, device=device)
-  output = torch.zeros((4096, 2048), dtype=torch.bfloat16, device=device)
 
   # Run torch reference
   output_torch = torch.matmul(lhs, rhs).to(device=cpu)
 
   def check_match(nki_func):
-    jit_func = nki_jit(nki_func)
-    jit_func(lhs.T, rhs, output)
+    output = nki_func(lhs.T, rhs)
     output_nki = output.to(device=cpu)
     if torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2):
       print("NKI and Torch match")
@@ -63,3 +61,4 @@ def check_match(nki_func):
 
   print("Checking correctness of nki_matmul_fully_optimized")
   check_match(nki_matmul_fully_optimized_)
+  # NKI_EXAMPLE_22_END
diff --git a/src/tutorials/rmsnorm/rmsnorm_jax.py b/src/tutorials/rmsnorm/rmsnorm_jax.py
index 5b412d8..f0efc20 100644
--- a/src/tutorials/rmsnorm/rmsnorm_jax.py
+++ b/src/tutorials/rmsnorm/rmsnorm_jax.py
@@ -7,9 +7,9 @@
 
 import jax
 import jax.numpy as jnp
-from jax_neuronx import nki_call
 from rmsnorm_nki_kernels import nki_rmsnorm_kernel
 
+# NKI_EXAMPLE_44_BEGIN
 # Reference JAX implementation
 def jax_rms_norm(a_tensor, g_tensor):
   # Square the tensor (element-wise)
@@ -26,11 +26,7 @@ def jax_rms_norm(a_tensor, g_tensor):
 a_tensor = jax.random.uniform(a_key, (250, 512))
 g_tensor = jax.random.uniform(g_key, (512,))
 
-output_nki = nki_call(
-  nki_rmsnorm_kernel,
-  a_tensor, g_tensor,
-  out_shape=jax.ShapeDtypeStruct(a_tensor.shape, dtype=a_tensor.dtype),
-)
+output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor)
 
 print(a_tensor)
 
@@ -43,3 +39,4 @@ def jax_rms_norm(a_tensor, g_tensor):
   print("NKI and JAX match")
 else:
   print("NKI and JAX differ")
+  # NKI_EXAMPLE_44_END
diff --git a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py b/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py
index 140b682..402eecd 100644
--- a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py
+++ b/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py
@@ -6,20 +6,23 @@
 """
 
 import numpy as np
+# NKI_EXAMPLE_42_BEGIN
 import math
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 
 
-def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor):
+@nki.jit
+def nki_rmsnorm_kernel(a_tensor, g_tensor):
   # Calculate out_tensor = a_tensor/RMS(a_tensor) * g_tensor
   # Where RMS(a_tensor) = sqrt((1/N) * sum(a_tensor * a_tensor))
   # and N = a_tensor.shape[1]
   # Reduction (mean) is performed in the free (2nd) dimension
+  out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype,
+                          buffer=nl.shared_hbm)
 
   # Make sure shapes match
   assert a_tensor.shape[1] == g_tensor.shape[0]
-  assert a_tensor.shape == out_tensor.shape
 
   # Generate tensor indices to index input tensor
   ix = nl.arange(128)[:, None]
@@ -68,14 +71,15 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor):
     nl.store(out_tensor[i * 128 + ix, iy], value=out_tile,
             mask=(i * 128 + ix < num_rows))
 
+  return out_tensor
+  # NKI_EXAMPLE_42_END
+
 
 if __name__ == "__main__":
   a = np.random.rand(128, 512).astype(np.float32)
   g = np.random.rand(512).astype(np.float32)
-  output_nki  = np.zeros(a.shape, dtype=a.dtype)
 
-  nki_rmsnorm_kernel_baremetal = nki.baremetal(nki_rmsnorm_kernel)
-  nki_rmsnorm_kernel_baremetal(a, g, output_nki)
+  output_nki = nki_rmsnorm_kernel(a, g)
   print(f"output_nki={output_nki}")
 
   # One-line numpy RMSNorm
diff --git a/src/tutorials/rmsnorm/rmsnorm_torch.py b/src/tutorials/rmsnorm/rmsnorm_torch.py
index 71ced3e..c9bfc69 100644
--- a/src/tutorials/rmsnorm/rmsnorm_torch.py
+++ b/src/tutorials/rmsnorm/rmsnorm_torch.py
@@ -5,11 +5,11 @@
 
 """
 
-from torch_neuronx.xla_impl.ops import nki_jit
 import torch
 import os
 from rmsnorm_nki_kernels import nki_rmsnorm_kernel
 
+# NKI_EXAMPLE_43_BEGIN
 # Reference torch implementation
 def torch_rmsnorm_kernel(a_tensor, g_tensor):
   # Square the tensor (element-wise)
@@ -25,13 +25,10 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor):
 from torch_xla.core import xla_model as xm
 device = xm.xla_device()
 
-nki_rmsnorm_kernel = nki_jit(nki_rmsnorm_kernel)
-
 a_tensor = torch.rand((250, 512), dtype=torch.float32).to(device=device)
 g_tensor = torch.rand((512), dtype=torch.float32).to(device=device)
-output_nki = torch.zeros((250, 512), dtype=torch.float32).to(device=device)
 
-nki_rmsnorm_kernel(a_tensor, g_tensor, output_nki)
+output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor)
 print(f"output_nki={output_nki}")
 
 output_torch = torch_rmsnorm_kernel(a_tensor, g_tensor)
@@ -41,3 +38,4 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor):
   print("NKI and Torch match")
 else:
   print("NKI and Torch differ")
+# NKI_EXAMPLE_43_END
diff --git a/src/tutorials/sd_attention/sd_attention_nki_kernels.py b/src/tutorials/sd_attention/sd_attention_nki_kernels.py
index e5eec25..6d1f781 100644
--- a/src/tutorials/sd_attention/sd_attention_nki_kernels.py
+++ b/src/tutorials/sd_attention/sd_attention_nki_kernels.py
@@ -12,7 +12,9 @@
 import argparse
 from scipy.special import softmax
 
-def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False,
+# NKI_EXAMPLE_31_BEGIN
+@nki.jit
+def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
                                            mixed_percision=True):
   """
   Fused self attention kernel for small head dimension Stable Diffusion workload, 
@@ -44,7 +46,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
   # Assume all IO tensors have the same dtype
   kernel_dtype = q_ref.dtype
   pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
-  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype
+  assert q_ref.dtype == k_ref.dtype == v_ref.dtype
 
   # Shape checking
   seqlen, d_head = q_ref.shape
@@ -53,7 +55,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
   assert tuple(k_ref.shape) == (seqlen, d_head), 'Input shape mismatch!'
   assert tuple(v_ref.shape) == (seqlen,d_head), \
   f'Input shape mismatch! Expected: {(seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
-  assert tuple(out_ref.shape) == (seqlen, d_head), 'Output shape mismatch!'
+  out_ref = nl.ndarray((seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm)
 
   # Softmax scaling factor, multiplied onto Q
   softmax_scale = 0.125
@@ -210,58 +212,61 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
       out_ref[i_q_seq_tile * q_seq_tile_size + if_out, ip_out],
       value=attn_res_div)
 
+  return out_ref
+# NKI_EXAMPLE_31_END
+
 
 def parse_args():
-    parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.")
-    parser.add_argument("--mode",
-                        choices=["accuracy", "perf"],
-                        default="accuracy",
-                        help="""Do accuracy test or perf test.
-                                Accuracy test uses cpu golden output as golden reference.
-                             """)
+  parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.")
+  parser.add_argument("--mode",
+                      choices=["accuracy", "perf"],
+                      default="accuracy",
+                      help="""Do accuracy test or perf test.
+                              Accuracy test uses cpu golden output as golden reference.
+                           """)
+
+  args = parser.parse_args()
+  return args
 
-    args = parser.parse_args()
-    return args
 
 def cpu_golden_attn(q, k, v):
-    softmax_scale = 0.125
+  softmax_scale = 0.125
 
-    q_scaled = q * softmax_scale
-    raw_score = np.matmul(q_scaled, k.transpose())
-    norm_score = softmax(raw_score, axis=-1)
+  q_scaled = q * softmax_scale
+  raw_score = np.matmul(q_scaled, k.transpose())
+  norm_score = softmax(raw_score, axis=-1)
 
-    return np.matmul(norm_score, v)
+  return np.matmul(norm_score, v)
 
 
 if __name__ == "__main__":
-    args = parse_args()
-
-    print(f"Running {args.mode} mode.")
-
-    seqlen, d_head = 4096, 64
-    
-    # Set up input tensors
-    dtype = np.float32
-    q_tensor = np.random.rand(seqlen, d_head).astype(dtype)
-    k_tensor = np.random.rand(seqlen, d_head).astype(dtype)
-    v_tensor = np.random.rand(seqlen, d_head).astype(dtype)
-    output_nki = np.empty((seqlen, d_head), dtype=dtype)
-    output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor)
-    
-    if args.mode == "accuracy":
-        nki.baremetal(fused_self_attn_for_SD_small_head_size)\
-                        (q_tensor, k_tensor, v_tensor, output_nki)
-        allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3)
-        print(f">>>> SD attention matches CPU reference? {allclose}")
-        assert allclose, "Accuracy check fails!"
-
-    else:
-        benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size,
-                        save_neff_name='file.neff',
-                        save_trace_name='profile.ntff')
-        benchmark_func(q_tensor, k_tensor, v_tensor, output_nki)
-        
-        metrics = benchmark_func.benchmark_result.nc_latency
-        print(">>>> SD attention benchmark results")
-        print("latency.p50 = " + str(metrics.get_latency_percentile(50)))
-        print("latency.p99 = " + str(metrics.get_latency_percentile(99)))
\ No newline at end of file
+  args = parse_args()
+
+  print(f"Running {args.mode} mode.")
+
+  seqlen, d_head = 4096, 64
+
+  # Set up input tensors
+  dtype = np.float32
+  q_tensor = np.random.rand(seqlen, d_head).astype(dtype)
+  k_tensor = np.random.rand(seqlen, d_head).astype(dtype)
+  v_tensor = np.random.rand(seqlen, d_head).astype(dtype)
+  output_nki = np.empty((seqlen, d_head), dtype=dtype)
+  output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor)
+
+  if args.mode == "accuracy":
+    output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor)
+    allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3)
+    print(f">>>> SD attention matches CPU reference? {allclose}")
+    assert allclose, "Accuracy check fails!"
+
+  else:
+    benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size,
+                                   save_neff_name='file.neff',
+                                   save_trace_name='profile.ntff')
+    benchmark_func(q_tensor, k_tensor, v_tensor)
+
+    metrics = benchmark_func.benchmark_result.nc_latency
+    print(">>>> SD attention benchmark results")
+    print("latency.p50 = " + str(metrics.get_latency_percentile(50)))
+    print("latency.p99 = " + str(metrics.get_latency_percentile(99)))
\ No newline at end of file
diff --git a/src/tutorials/sd_attention/sd_attention_torch.py b/src/tutorials/sd_attention/sd_attention_torch.py
index f124607..639e5cf 100644
--- a/src/tutorials/sd_attention/sd_attention_torch.py
+++ b/src/tutorials/sd_attention/sd_attention_torch.py
@@ -5,8 +5,8 @@
 
 """
 
+# NKI_EXAMPLE_32_BEGIN
 import torch
-from torch_neuronx.xla_impl.ops import nki_jit
 from torch_xla.core import xla_model as xm
 
 from sd_attention_nki_kernels import fused_self_attn_for_SD_small_head_size
@@ -28,10 +28,8 @@ def cpu_golden_attn(q, k, v):
   q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
   k_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
   v_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
-  output_nki = torch.zeros((4096, 64), dtype=torch.float32).to(device=device)
 
-  nki_func = nki_jit(func=fused_self_attn_for_SD_small_head_size)
-  nki_func(q_tensor, k_tensor, v_tensor, output_nki)
+  output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor)
 
   output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor)
 
@@ -42,4 +40,5 @@ def cpu_golden_attn(q, k, v):
   else:
     print("NKI and Torch differ")
 
-  assert allclose
\ No newline at end of file
+  assert allclose
+  # NKI_EXAMPLE_32_END
diff --git a/src/tutorials/tensor_addition/tensor_addition_jax.py b/src/tutorials/tensor_addition/tensor_addition_jax.py
index 9655b84..e40f962 100644
--- a/src/tutorials/tensor_addition/tensor_addition_jax.py
+++ b/src/tutorials/tensor_addition/tensor_addition_jax.py
@@ -4,42 +4,15 @@
 JAX implementation for tensor addition NKI tutorial.
 
 """
+# NKI_EXAMPLE_30_BEGIN
 import jax
 import jax.numpy as jnp
-from jax_neuronx import nki_call
+# NKI_EXAMPLE_30_END
 
-from tensor_addition_nki_kernels import nki_tensor_add_kernel_
-
-
-def nki_tensor_add(a_input, b_input):
-  """NKI kernel caller to compute element-wise addition of two input tensors
-
-  This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs
-
-  Args:
-      a_input: a first input tensor, of shape [N*128, M*512]
-      b_input: a second input tensor, of shape [N*128, M*512]
-
-  Returns:
-      a tensor of shape [N*128, M*512], the result of a_input + b_input
-  """
-
-  # The SPMD launch grid denotes the number of kernel instances.
-  # In this case, we use a 2D grid where the size of each invocation is 128x512
-  grid_x = a_input.shape[0] // 128
-  grid_y = a_input.shape[1] // 512
-
-  out_shape = jax.ShapeDtypeStruct((a_input.shape[0], a_input.shape[1]), dtype=a_input.dtype)
-
-  return nki_call(
-      nki_tensor_add_kernel_,
-      a_input,
-      b_input,
-      grid=(grid_x, grid_y),
-      out_shape=out_shape,
-    )
+from tensor_addition_nki_kernels import nki_tensor_add
 
 
+# NKI_EXAMPLE_30_BEGIN
 if __name__ == "__main__":
 
   seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
@@ -59,3 +32,4 @@ def nki_tensor_add(a_input, b_input):
     print("NKI and JAX differ")
 
   assert allclose
+  # NKI_EXAMPLE_30_END
diff --git a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py b/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py
index 2b49237..ea72488 100644
--- a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py
+++ b/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py
@@ -5,20 +5,26 @@
 
 """
 import numpy as np
+# NKI_EXAMPLE_27_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 
 
-def nki_tensor_add_kernel_(a_input, b_input, c_output):
+@nki.jit
+def nki_tensor_add_kernel_(a_input, b_input):
   """NKI kernel to compute element-wise addition of two input tensors
 
-  This kernel assumes strict input/output tile-sizes, of up-to [128,512]
+  This kernel assumes strict input/output sizes can be uniformly tiled to [128,512]
 
   Args:
-      a_input: a first input tensor, of shape [128,512]
-      b_input: a second input tensor, of shape [128,512]
-      c_output: an output tensor, of shape [128,512]
+      a_input: a first input tensor
+      b_input: a second input tensor
+
+  Returns:
+      c_output: an output tensor
   """
+  # Create output tensor shared between all SPMD instances as result tensor
+  c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)
 
   # Calculate tile offsets based on current 'program'
   offset_i_x = nl.program_id(0) * 128
@@ -39,7 +45,12 @@ def nki_tensor_add_kernel_(a_input, b_input, c_output):
   # store the addition results back to device memory (c_output)
   nl.store(c_output[ix, iy], value=c_tile)
 
+  # Transfer the ownership of `c_output` to the caller
+  return c_output
+  # NKI_EXAMPLE_27_END
+
 
+# NKI_EXAMPLE_28_BEGIN
 def nki_tensor_add(a_input, b_input):
   """NKI kernel caller to compute element-wise addition of two input tensors
 
@@ -57,12 +68,9 @@ def nki_tensor_add(a_input, b_input):
   # In this case, we use a 2D grid where the size of each invocation is 128x512
   grid_x = a_input.shape[0] // 128
   grid_y = a_input.shape[1] // 512
-  c_output = np.zeros(a_input.shape, dtype=a_input.dtype)
-
-  nki_tensor_add_kernel_baremetal = nki.baremetal(nki_tensor_add_kernel_)
-  nki_tensor_add_kernel_baremetal[grid_x, grid_y](a_input, b_input, c_output)
 
-  return c_output
+  return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
+  # NKI_EXAMPLE_28_END
 
 
 if __name__ == "__main__":
diff --git a/src/tutorials/tensor_addition/tensor_addition_torch.py b/src/tutorials/tensor_addition/tensor_addition_torch.py
index 942e728..83673e5 100644
--- a/src/tutorials/tensor_addition/tensor_addition_torch.py
+++ b/src/tutorials/tensor_addition/tensor_addition_torch.py
@@ -4,38 +4,15 @@
 PyTorch implementation for tensor addition NKI tutorial.
 
 """
+# NKI_EXAMPLE_29_BEGIN
 import torch
 from torch_xla.core import xla_model as xm
-from torch_neuronx import nki_jit
+# NKI_EXAMPLE_29_END
 
-from tensor_addition_nki_kernels import nki_tensor_add_kernel_
+from tensor_addition_nki_kernels import nki_tensor_add
 
 
-def nki_tensor_add(a_input, b_input):
-  """NKI kernel caller to compute element-wise addition of two input tensors
-
-  This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs
-
-  Args:
-      a_input: a first input tensor, of shape [N*128, M*512]
-      b_input: a second input tensor, of shape [N*128, M*512]
-
-  Returns:
-      a tensor of shape [N*128, M*512], the result of a_input + b_input
-  """
-
-  # The SPMD launch grid denotes the number of kernel instances.
-  # In this case, we use a 2D grid where the size of each invocation is 128x512
-  grid_x = a_input.shape[0] // 128
-  grid_y = a_input.shape[1] // 512
-  c_output = torch.zeros(a_input.shape, dtype=a_input.dtype).to(device=device)
-
-  # Decorate the NKI kernel for PyTorch tracing
-  nki_tensor_add_kernel_torch = nki_jit(nki_tensor_add_kernel_)
-  nki_tensor_add_kernel_torch[grid_x, grid_y](a_input, b_input, c_output)
-
-  return c_output
-
+# NKI_EXAMPLE_29_BEGIN
 if __name__ == "__main__":
   device = xm.xla_device()
 
@@ -55,3 +32,4 @@ def nki_tensor_add(a_input, b_input):
     print("NKI and Torch differ")
 
   assert allclose
+  # NKI_EXAMPLE_29_END
diff --git a/src/tutorials/transpose2d/transpose2d_jax.py b/src/tutorials/transpose2d/transpose2d_jax.py
index 024782c..f23ceef 100644
--- a/src/tutorials/transpose2d/transpose2d_jax.py
+++ b/src/tutorials/transpose2d/transpose2d_jax.py
@@ -5,25 +5,18 @@
 
 """
 
+# NKI_EXAMPLE_36_BEGIN
 import jax
 import jax.numpy as jnp
-from functools import partial
-from jax_neuronx import nki_call
+# NKI_EXAMPLE_36_END
 
 from transpose2d_nki_kernels import tensor_transpose2D_kernel_
 
-
-def transpose2D(in_tensor, shape2D):
-  return nki_call(
-    partial(tensor_transpose2D_kernel_, shape2D=shape2D),
-    in_tensor,
-    out_shape=jax.ShapeDtypeStruct(in_tensor.shape, dtype=in_tensor.dtype)
-  )
-
+# NKI_EXAMPLE_36_BEGIN
 if __name__ == "__main__":
   P, X, Y = 5, 37, 44
   a = jax.random.uniform(jax.random.PRNGKey(42), (P, X * Y))
-  a_t_nki = transpose2D(a, (X, Y))
+  a_t_nki = tensor_transpose2D_kernel_(a, shape2D=(X, Y))
 
   a_t_jax = jnp.transpose(a.reshape(P, X, Y), axes=(0, 2, 1)).reshape(P, X * Y)
   print(a, a_t_nki, a_t_jax)
@@ -35,3 +28,4 @@ def transpose2D(in_tensor, shape2D):
     print("NKI and JAX differ")
 
   assert allclose
+# NKI_EXAMPLE_36_END
diff --git a/src/tutorials/transpose2d/transpose2d_nki_kernels.py b/src/tutorials/transpose2d/transpose2d_nki_kernels.py
index d993c7e..171e6ed 100644
--- a/src/tutorials/transpose2d/transpose2d_nki_kernels.py
+++ b/src/tutorials/transpose2d/transpose2d_nki_kernels.py
@@ -5,11 +5,13 @@
 """
 
 import numpy as np
+# NKI_EXAMPLE_33_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 
 
-def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
+@nki.jit
+def tensor_transpose2D_kernel_(in_tensor, shape2D):
   """
   NKI kernel to reorder the elements on axis[1] of the input tensor.
 
@@ -36,6 +38,8 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
     shape2D: tuple representing the dimensions to be transposed: (#rows, #cols)
     out_tensor: an output (transposed) tensor
   """
+  out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype,
+                          buffer=nl.shared_hbm)
   # Gather input shapes
   sz_p, _ = in_tensor.shape
 
@@ -64,14 +68,15 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
   # Finally, we store out_tile to external memory
   nl.store(out_tensor, value=out_tile)
 
+  return out_tensor
+  # NKI_EXAMPLE_33_END
+
 
 if __name__ == "__main__":
   P, X, Y = 5, 3, 4
   a = np.arange(P*X*Y, dtype=np.int8).reshape((P, X*Y))
-  a_t_nki = np.zeros((P, Y*X), dtype=np.int8)
 
-  tensor_transpose2D_kernel_torch = nki.baremetal(tensor_transpose2D_kernel_)
-  tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y))
+  a_t_nki = tensor_transpose2D_kernel_(a, (X, Y))
 
   a_t_np = np.transpose(a.reshape(P, X, Y), (0, 2, 1)).reshape(P, X * Y)
 
diff --git a/src/tutorials/transpose2d/transpose2d_torch.py b/src/tutorials/transpose2d/transpose2d_torch.py
index 71083d7..61fe367 100644
--- a/src/tutorials/transpose2d/transpose2d_torch.py
+++ b/src/tutorials/transpose2d/transpose2d_torch.py
@@ -4,13 +4,15 @@
 PyTorch implementation for transpose2d NKI tutorial.
 """
 
+# NKI_EXAMPLE_34_BEGIN
 import torch
 from torch_xla.core import xla_model as xm
-from torch_neuronx import nki_jit
+# NKI_EXAMPLE_34_END
 
 from transpose2d_nki_kernels import tensor_transpose2D_kernel_
 
 
+# NKI_EXAMPLE_34_BEGIN
 if __name__ == "__main__":
   device = xm.xla_device()
 
@@ -18,8 +20,7 @@
   a = torch.arange(P*X*Y, dtype=torch.int8).reshape((P, X*Y)).to(device=device)
   a_t_nki = torch.zeros((P, Y*X), dtype=torch.int8).to(device=device)
 
-  tensor_transpose2D_kernel_torch = nki_jit(tensor_transpose2D_kernel_)
-  tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y))
+  a_t_nki = tensor_transpose2D_kernel_(a, (X, Y))
 
   a_t_torch = torch.transpose(a.reshape(P, X, Y), 1, 2).reshape(P, X * Y)
 
@@ -32,3 +33,4 @@
     print("NKI and PyTorch differ")
 
   assert allclose
+  # NKI_EXAMPLE_34_END
diff --git a/test/integration/flash_attention/flash_attention_benchmark.py b/test/integration/flash_attention/flash_attention_benchmark.py
index 5aa2e40..918a14f 100644
--- a/test/integration/flash_attention/flash_attention_benchmark.py
+++ b/test/integration/flash_attention/flash_attention_benchmark.py
@@ -14,6 +14,8 @@
 
 from flash_attention import nki_flash_attn_func
 
+parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(parent_dir)
 from perf_utils.LatencyCollector import benchmark
 
 if len(sys.argv) != 2:
diff --git a/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py b/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py
index e7fd205..5d63424 100644
--- a/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py
+++ b/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py
@@ -8,6 +8,8 @@
 from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
 
+parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(parent_dir)
 from perf_utils.LatencyCollector import benchmark
 
 import sys
diff --git a/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py b/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py
index 3ba0eab..4970f72 100644
--- a/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py
+++ b/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py
@@ -23,6 +23,8 @@
 else:
     from diffusers.models.cross_attention import CrossAttention
 
+parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(parent_dir)
 from perf_utils.LatencyCollector import benchmark
 
 import sys
diff --git a/test/unit/test_SD_attention_small_head.py b/test/unit/test_SD_attention_small_head.py
index 5480fa4..32e6945 100644
--- a/test/unit/test_SD_attention_small_head.py
+++ b/test/unit/test_SD_attention_small_head.py
@@ -11,7 +11,7 @@
 
 test_trace_file_path='local_trace.ntff'
 numeric_func = baremetal(fused_self_attn_for_SD_small_head_size)
-bench_func = benchmark(warmup=5, iters=10, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size)
+bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size)
 
 def cpu_golden_attn(q, k, v):
     softmax_scale = 0.125
@@ -34,16 +34,16 @@ def test_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency):
         q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
         k = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
         v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
-        out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype)
-        
+
         q_dev = nl.static_cast(q, dtype)
         k_dev = nl.static_cast(k, dtype)
         v_dev = nl.static_cast(v, dtype)
 
-        bench_func[bs](q_dev, k_dev, v_dev, out)
-        latency_res = bench_func.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
-        assert p99 <= latency
+        bench_func_ = bench_func[bs]
+        bench_func_(q_dev, k_dev, v_dev)
+        latency_res = bench_func_.benchmark_result.nc_latency
+        p50 = latency_res.get_latency_percentile(50)
+        assert p50 <= latency*1.05 # short running kernels are subjected to hardware fluctuation
         assert os.path.getsize(test_trace_file_path) > 0
 
     @pytest.mark.parametrize("bs,seqlen,d,dtype", [
@@ -54,13 +54,12 @@ def test_attention_for_SD_numberic(self, bs, seqlen, d, dtype):
         q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
         k = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
         v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
-        out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype)
-        
+
         q_dev = nl.static_cast(q, dtype)
         k_dev = nl.static_cast(k, dtype)
         v_dev = nl.static_cast(v, dtype)
 
-        numeric_func[bs](q_dev, k_dev, v_dev, out)
+        out = numeric_func[bs](q_dev, k_dev, v_dev)
         out = nl.static_cast(out, np.float32)
         golden_result = cpu_golden_attn(q, k, v)
         assert np.allclose(out, golden_result, atol=1e-2)
diff --git a/test/unit/test_allocated_SD_attention_small_head.py b/test/unit/test_allocated_SD_attention_small_head.py
new file mode 100644
index 0000000..ee0de86
--- /dev/null
+++ b/test/unit/test_allocated_SD_attention_small_head.py
@@ -0,0 +1,67 @@
+"""
+Copyright (c) 2023, Amazon.com. All Rights Reserved
+"""
+import os
+import pytest
+from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size
+from neuronxcc.nki import benchmark, baremetal
+import neuronxcc.nki as nki
+import neuronxcc.nki.language as nl
+import numpy as np
+from scipy.special import softmax
+
+test_trace_file_path='local_trace.ntff'
+numeric_func = baremetal(allocated_fused_self_attn_for_SD_small_head_size)
+bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(allocated_fused_self_attn_for_SD_small_head_size)
+
+def cpu_golden_attn(q, k, v):
+    softmax_scale = 0.125
+    q_scaled = q * softmax_scale
+    raw_score = np.matmul(q_scaled.transpose(0, 2, 1), k)
+    
+    norm_score = softmax(raw_score, axis=-1)
+
+    # Transpose the result so it has the same layout as ours
+    return np.matmul(norm_score, v).transpose(0, 2, 1)
+
+class TestAttention:
+
+    @pytest.mark.parametrize("bs,seqlen,d,dtype,latency", [
+        [1, 4096, 128, np.float32, 410],
+        [1, 4096, 128, nl.bfloat16, 350],
+        [1, 5120, 128, nl.bfloat16, 586]
+    ])
+    def test_allocated_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency):
+        q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
+        k = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
+        v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
+
+        q_dev = nl.static_cast(q, dtype)
+        k_dev = nl.static_cast(k, dtype)
+        v_dev = nl.static_cast(v, dtype)
+
+        bench_func_ = bench_func[bs]
+        bench_func_(q_dev, k_dev, v_dev)
+        latency_res = bench_func_.benchmark_result.nc_latency
+        p50 = latency_res.get_latency_percentile(50)
+        assert p50 <= latency * 1.05 # short running kernels are subjected to hardware fluctuation
+        assert os.path.getsize(test_trace_file_path) > 0
+
+    @pytest.mark.parametrize("bs,seqlen,d,dtype", [
+        [1, 4096, 128, np.float32],
+        [1, 4096, 128, nl.bfloat16],
+        [1, 5120, 128, nl.bfloat16]
+    ])
+    def test_allocated_attention_for_SD_numberic(self, bs, seqlen, d, dtype):
+        q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
+        k = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
+        v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
+
+        q_dev = nl.static_cast(q, dtype)
+        k_dev = nl.static_cast(k, dtype)
+        v_dev = nl.static_cast(v, dtype)
+
+        out = numeric_func[bs](q_dev, k_dev, v_dev)
+        out = nl.static_cast(out, np.float32)
+        golden_result = cpu_golden_attn(q, k, v)
+        assert np.allclose(out, golden_result, atol=1e-2)
diff --git a/test/unit/test_flash_attn_bwd.py b/test/unit/test_flash_attn_bwd.py
index a55abbe..3aedab0 100644
--- a/test/unit/test_flash_attn_bwd.py
+++ b/test/unit/test_flash_attn_bwd.py
@@ -7,6 +7,8 @@
 import neuronxcc.nki.language as nl
 import numpy as np
 
+from TestDecorators import xfail
+
 numeric_func = baremetal(flash_attn_bwd)
 bench_func = benchmark(warmup=5, iters=10)(flash_attn_bwd)
 
@@ -85,6 +87,7 @@ def mixed_precision_matmul(a, b):
 
 class TestAttention:
 
+    @xfail # P167481231
     @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, latency", [
         [1, 4, 32*1024, 128, nl.bfloat16, 117000],
     ])
@@ -97,23 +100,16 @@ def test_flash_attn_bwd_perf(self, bs, nheads, seqlen, d, dtype, latency):
         lse = np.random.random_sample([bs, nheads, nl.tile_size.pmax, seqlen // nl.tile_size.pmax]).astype(np.float32)
         seed = None
 
-        out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        out_dk = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        
         q = nl.static_cast(q, dtype)
         k = nl.static_cast(k, dtype)
         v = nl.static_cast(v, dtype)
         o_proj = nl.static_cast(o_proj, dtype)
         dy = nl.static_cast(dy, dtype)  
-        out_dq = nl.static_cast(out_dq, dtype)
-        out_dk = nl.static_cast(out_dk, dtype)
-        out_dv = nl.static_cast(out_dv, dtype)
-
-        bench_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
-                               out_dq, out_dk, out_dv,
-                               use_causal_mask=True, mixed_precision=True)
-        latency_res = bench_func.benchmark_result.nc_latency
+
+        bench_func_ = bench_func[bs, nheads]
+        bench_func_(q, k, v, o_proj, dy, lse, seed,
+                    use_causal_mask=True, mixed_precision=True)
+        latency_res = bench_func_.benchmark_result.nc_latency
         p99 = latency_res.get_latency_percentile(99)
         assert p99 <= latency
 
@@ -130,10 +126,7 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype):
         v = nl.static_cast(v, dtype)
         dy = nl.static_cast(dy, dtype)
         seed = None
-        out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        out_dk = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-  
+
         dq_golden, dk_golden, dv_golden, cached_negative_max, cached_sum_reciprocal, o_proj = \
           cpu_attention_backward(q, k, v, dy, use_causal_mask=True)
         cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
@@ -142,9 +135,9 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype):
                                                               nl.tile_size.pmax).transpose(0, 1, 3, 2)
         lse = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal))
 
-        numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
-                                 out_dq, out_dk, out_dv,
-                                 use_causal_mask=True, mixed_precision=True)
+        out_dq, out_dk, out_dv = numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
+                                                          use_causal_mask=True,
+                                                          mixed_precision=True)
 
         assert np.allclose(out_dq, dq_golden, atol=1e-2)
         assert np.allclose(out_dk, dk_golden, atol=1e-2)
diff --git a/test/unit/test_flash_attn_fwd.py b/test/unit/test_flash_attn_fwd.py
index 4d91164..fff4ac2 100644
--- a/test/unit/test_flash_attn_fwd.py
+++ b/test/unit/test_flash_attn_fwd.py
@@ -63,75 +63,84 @@ def mixed_precision_matmul(a, b):
  
 class TestAttention:
  
-    @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\
+    @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\
                               mixed_precision, training, tile_size, kv_heads, should_transpose_v, latency", [
-    [1, 6, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000],
-    [1, 1, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000],
+    [1, 6, 32*1024, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000],
+    [1, 1, 32*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000],
+    # Non-square
+    [1, 3, 32*1024, 16*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000],
+    [1, 3, 16*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000],
     ])
-    def test_flash_attn_fwd_perf(self, bs, nheads, seqlen, d, dtype, use_causal_mask, 
+    def test_flash_attn_fwd_perf(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, 
                                  mixed_precision, training, tile_size, kv_heads, should_transpose_v,latency):
-        q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
-        k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
+        q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2
+        k = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2
         if should_transpose_v:
-            v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
+            v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2
         else:
-            v = (np.random.random_sample([bs, nheads, seqlen, d]) - 0.5) * 2
-        o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype)
-        out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax], 
+            v = (np.random.random_sample([bs, nheads, seqlen_k, d]) - 0.5) * 2
+        o_proj = np.zeros(shape=[bs, nheads, seqlen_q, d], dtype=dtype)
+        out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen_q // nl.tile_size.pmax], 
                                   dtype=nl.float32 if mixed_precision else dtype) if training else None
         seed = None
         
         q = nl.static_cast(q, dtype)
         k = nl.static_cast(k, dtype)
         v = nl.static_cast(v, dtype)
-        o_proj = nl.static_cast(o_proj, dtype)
         config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v})
 
         heads = nheads if kv_heads is None else kv_heads
 
-        bench_func[bs, heads](q, k, v, seed, o_proj, out_lse,
-                               use_causal_mask=use_causal_mask, mixed_precision=mixed_precision, config=config)
-        latency_res = bench_func.benchmark_result.nc_latency
+        bench_func_ = bench_func[bs, heads]
+        bench_func_(q, k, v, seed, use_causal_mask=use_causal_mask,
+                    mixed_precision=mixed_precision, config=config)
+        latency_res = bench_func_.benchmark_result.nc_latency
         p99 = latency_res.get_latency_percentile(99)
         assert p99 <= latency
  
-    @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\
+    @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\
                               training, tile_size, kv_heads, should_transpose_v", [
-    [1, 6, 4096, 128, np.float32, True, True, 2048, 3, False],
-    [1, 1, 4096, 128, np.float32, True, False, 2048, None, False],
+    [1, 6, 4096, 4096, 128, np.float32, True, True, 2048, 3, False],
+    [1, 1, 4096, 4096, 128, np.float32, True, False, 2048, None, False],
+    [1, 1, 8192, 4096, 128, np.float32, True, False, 2048, None, False],
+    [1, 1, 4096, 8192, 128, np.float32, True, False, 2048, None, False],
     ])
-    def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen, d, dtype, use_causal_mask, 
+    def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, 
                                      training, tile_size, kv_heads, should_transpose_v):
-        q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
-        k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen]) - 0.5) * 2
+        q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2
+        k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen_k]) - 0.5) * 2
         if should_transpose_v:
-            v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
+            v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2
             cpu_permute = (0, 1, 2, 3)
         else:
-            v = (np.random.random_sample([bs, kv_heads or nheads, seqlen, d]) - 0.5) * 2
+            v = (np.random.random_sample([bs, kv_heads or nheads, seqlen_k, d]) - 0.5) * 2
             cpu_permute = (0, 1, 3, 2)
-        o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype)
+
         q = nl.static_cast(q, dtype)
         k = nl.static_cast(k, dtype)
         v = nl.static_cast(v, dtype)
         seed = None
-        out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax], 
-                                  dtype=np.float32) if training else None
 
         o_proj_golden, cached_negative_max, cached_sum_reciprocal  = \
           cpu_attention_forward(q, k, v.transpose(cpu_permute), use_causal_mask=use_causal_mask,mixed_precision=True)
         o_proj_golden = o_proj_golden.transpose(0,1,3,2) # (b,h, d, seq)
-        cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
+        cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax,
                                                           nl.tile_size.pmax).transpose(0, 1, 3, 2)
-        cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
+        cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax,
                                                               nl.tile_size.pmax).transpose(0, 1, 3, 2)
         lse_golden = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal)) if training else None
         config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v})
 
         heads = nheads if kv_heads is None else kv_heads
-        numeric_func[bs, heads](q, k, v, seed, o_proj, out_lse, seed,
-                                 use_causal_mask=use_causal_mask, mixed_precision=True, config=config)
+        results = numeric_func[bs, heads](q, k, v, seed,
+                                          use_causal_mask=use_causal_mask,
+                                          mixed_precision=True,
+                                          config=config)
 
-        assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
         if training:
+            o_proj, out_lse = results
+            assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
             assert np.allclose(out_lse, lse_golden, atol=1e-2)
+        else:
+            o_proj = results
+            assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
diff --git a/test/unit/test_resize_nearest.py b/test/unit/test_resize_nearest.py
index a77968b..2bbc601 100644
--- a/test/unit/test_resize_nearest.py
+++ b/test/unit/test_resize_nearest.py
@@ -11,6 +11,7 @@
 numeric_func = baremetal(resize_nearest_fixed_dma_kernel)
 bench_func = benchmark(warmup=5, iters=10)(resize_nearest_fixed_dma_kernel)
 
+
 def cpu_golden_result(data_tensor, output_shape):
     in_b, in_h, in_w, in_c = data_tensor.shape
     out_b, out_h, out_w, out_c = output_shape
@@ -36,18 +37,18 @@ def cpu_golden_result(data_tensor, output_shape):
 class TestResizeNearest:
 
     @pytest.mark.parametrize("in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency", [
- 	    [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1722],
+ 	    [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1740],
         [1, 30, 20, 1280, 1, 59, 38, 1280, nl.float16, 659],
         [1, 30, 20, 1280, 1, 59, 38, 1280, nl.bfloat16, 659],
  	])
     def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency):
         input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32)
-        output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype)
-        
+
         input_dev = nl.static_cast(input_tensor, dtype)
 
-        bench_func[in_b](input_dev, output_tensor)
-        latency_res = bench_func.benchmark_result.nc_latency
+        bench_func_ = bench_func[in_b]
+        bench_func_(input_dev, (out_b, out_h, out_w, out_c))
+        latency_res = bench_func_.benchmark_result.nc_latency
         p99 = latency_res.get_latency_percentile(99)
         assert p99 <= latency
 
@@ -58,11 +59,10 @@ def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out
  	])
     def test_resize_nearest_for_numberic(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype):
         input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32)
-        output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype)
-        
+
         input_dev = nl.static_cast(input_tensor, dtype)
 
-        numeric_func[in_b](input_dev, output_tensor)
+        output_tensor = numeric_func[in_b](input_dev, (out_b, out_h, out_w, out_c))
         output_tensor = nl.static_cast(output_tensor, np.float32)
         golden_result = cpu_golden_result(input_tensor, output_tensor.shape)
         assert np.allclose(output_tensor, golden_result, atol=1e-2)
diff --git a/test/unit/test_rmsnorm_qkv.py b/test/unit/test_rmsnorm_qkv.py
new file mode 100644
index 0000000..24ad31c
--- /dev/null
+++ b/test/unit/test_rmsnorm_qkv.py
@@ -0,0 +1,65 @@
+"""
+Copyright (c) 2024, Amazon.com. All Rights Reserved
+"""
+import pytest
+from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv
+from neuronxcc.nki import benchmark, baremetal
+import neuronxcc.nki.language as nl
+import numpy as np
+
+numeric_func = baremetal(allocated_fused_rms_norm_qkv)
+bench_func = benchmark(warmup=5, iters=10)(allocated_fused_rms_norm_qkv)
+
+np.random.seed(0)
+
+
+def rms_norm(hidden, gamma, eps=1e-6):
+  rms = np.sqrt(np.mean(np.square(hidden), axis=-1, keepdims=True))
+  norm = hidden * np.reciprocal(rms + eps)
+  if gamma is not None:
+    norm *= gamma
+  return norm
+
+def cpu_golden_result(hidden, gamma, qkv_weights, dtype, do_norm=True):
+  if do_norm:
+      hidden = rms_norm(hidden, gamma)
+  qkv_out = (hidden @ qkv_weights).astype(dtype)
+  return qkv_out
+
+class TestRMSNormQKV:
+  @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype, latency", [
+    [1, 128, 512, 512, np.float16, 25],
+    [1, 512, 1024, 384, nl.bfloat16, 40],
+    [1, 128, 1024, 512, nl.bfloat16, 28],
+    [1, 1024, 8192, 512, nl.bfloat16, 301 * 1.02]
+  ])
+  def test_allocated_rmsnorm_qkv_perf(self, batch, seqlen, dim, d_head, dtype, latency):
+    hidden = np.random.random_sample((batch, seqlen, dim)).astype(np.float32)
+    weights = np.random.random_sample((dim, d_head)).astype(np.float32)
+
+    hidden = nl.static_cast(hidden, dtype)
+    weights = nl.static_cast(weights, dtype)
+
+    bench_func(hidden, weights)
+    latency_res = bench_func.benchmark_result.nc_latency
+    p99 = latency_res.get_latency_percentile(99)
+    assert p99 <= latency
+
+  @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype", [
+    [1, 128, 512, 512, np.float16],
+    [1, 512, 1024, 384, nl.bfloat16],
+    [1, 128, 1024, 512, nl.bfloat16],
+    [1, 1024, 8192, 512, nl.bfloat16]
+  ])
+  def test_allocated_rmsnorm_qkv_numeric(self, batch, seqlen, dim, d_head, dtype):
+    hidden = np.random.random_sample((batch, seqlen, dim))
+    weights = np.random.random_sample((dim, d_head))
+
+    hidden_dev = nl.static_cast(hidden, dtype)
+    weights_dev = nl.static_cast(weights, dtype)
+
+    out = numeric_func(hidden_dev, weights_dev)
+    out = nl.static_cast(out, np.float32)
+    golden_res = nl.static_cast(cpu_golden_result(hidden, None, weights, dtype, do_norm=True), np.float32)
+    assert np.allclose(out, golden_res, atol=1e-2, rtol=1e-2)
+
diff --git a/test/unit/test_select_and_scatter.py b/test/unit/test_select_and_scatter.py
index 70f7a7c..fc99b37 100644
--- a/test/unit/test_select_and_scatter.py
+++ b/test/unit/test_select_and_scatter.py
@@ -39,7 +39,6 @@ def cpu_golden_result(operand_tensor, source_tensor, window_dimensions=(3, 3), w
                     out_h = h * stride_h + local_h - padding[0]
                     out_w = w * stride_w + local_w - padding[1]
                     output_tensor[n, c, out_h, out_w] += source_tensor[n, c, h, w]
-
     return output_tensor
 
 class TestSelectAndScatter:
@@ -47,31 +46,28 @@ class TestSelectAndScatter:
  	    [8, 64, 112, 112, 56, 56, np.float32, 4500],
  	])
     def test_select_and_scatter_for_perf(self, n, c, operand_h, operand_w, source_h, source_w, dtype, latency):
-        operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32)
-        source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32)
-        output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype)
-        
-        operand_dev = nl.static_cast(operand_tensor, dtype)
-        source_dev = nl.static_cast(source_tensor, dtype)
+        operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype)
+        source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype)
 
-        bench_func(operand_dev, source_dev, output_tensor)
+        bench_func(operand_dev, source_dev)
         latency_res = bench_func.benchmark_result.nc_latency
         p99 = latency_res.get_latency_percentile(99)
         assert p99 <= latency
 
     @pytest.mark.parametrize("n, c, operand_h, operand_w, source_h, source_w, dtype", [
  	    [8, 64, 112, 112, 56, 56, np.float32],
- 	    pytest.param(8, 64, 112, 112, 56, 56, nl.bfloat16, marks=pytest.mark.xfail),
+ 	    [8, 64, 112, 112, 56, 56, nl.bfloat16],
  	])
     def test_select_and_scatter_for_numeric(self, n, c, operand_h, operand_w, source_h, source_w, dtype):
-        operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32)
-        source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32)
-        output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype)
-        
-        operand_dev = nl.static_cast(operand_tensor, dtype)
-        source_dev = nl.static_cast(source_tensor, dtype)
+        operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype)
+        source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype)
+
+        sw = nl.static_cast(np.ndarray(shape=(n, c, source_h, source_w, 3, 3)), dtype)
+        operand_tensor = nl.static_cast(operand_dev, np.float32)
+        source_tensor = nl.static_cast(source_dev, np.float32)
 
-        numeric_func(operand_dev, source_dev, output_tensor)
+        output_dev = numeric_func(operand_dev, source_dev)
         golden_result = cpu_golden_result(operand_tensor, source_tensor)
-        output_tensor = nl.static_cast(output_tensor, np.float32)
-        assert np.allclose(output_tensor, golden_result)
\ No newline at end of file
+        nki_result = nl.static_cast(output_dev, np.float32)
+
+        assert np.allclose(nki_result, golden_result, rtol=1e-2, atol=1e-2)

From bdeca88513a228c182a2b46a9b124c3fbdf0eab0 Mon Sep 17 00:00:00 2001
From: aws-qieqingy <122939906+aws-qieqingy@users.noreply.github.com>
Date: Fri, 13 Dec 2024 20:52:29 -0500
Subject: [PATCH 2/3] Recover documentation change overwritten by code push
 (#7)

---
 CONTRIBUTING.md | 69 ++++++++++++++++++++++++++++++++----
 LICENSE         | 16 +++++++++
 LICENSE.txt     |  1 -
 README.md       | 93 ++++++++++---------------------------------------
 4 files changed, 97 insertions(+), 82 deletions(-)
 create mode 100644 LICENSE
 delete mode 100644 LICENSE.txt

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 32ce44e..4c16260 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,6 +1,6 @@
 # Contributing Guidelines
 
-Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
+Thank you for your interest in contributing to our project. Whether it's a new NKI kernel, improving existing kernel code, bug fix, new feature, correction, or additional
 documentation, we greatly value feedback and contributions from our community.
 
 Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
@@ -24,14 +24,13 @@ reported the issue. Please try to include as much information as you can. Detail
 Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
 
 1. You are working against the latest source on the *main* branch.
-2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
-3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
+2. You check existing open, and recently merged pull requests to make sure someone else hasn't addressed the problem already.
 
 To send us a pull request, please:
 
 1. Fork the repository.
-2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
-3. Ensure local tests pass.
+2. Modify the source; please focus on the specific changes you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
+3. Please ensure your change satisfies the requirements listed in [Testing Requirements](#testing-requirements) and [Coding Guidelines](#coding-guidelines)
 4. Commit to your fork using clear commit messages.
 5. Send us a pull request, answering any default questions in the pull request interface.
 6. Wait for a repository collaborator to look at your pull request, run the automated tests, and review. If additional changes or discussion is needed, a collaborator will get back to you, so please stay involved in the conversation.
@@ -40,8 +39,64 @@ To send us a pull request, please:
 GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
 [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
 
+### Testing Requirements
+Running the binaries for a NKI kernel require Neuron devices on an AWS EC2 instance from trn1, trn1n, or inf2 instance families. 
+Details on setting up an instance can be found in [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html).
 
-## Finding contributions to work on
+If you would like to test your kernel without requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` input/output tensors and types. 
+An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py). However, kernels with _only_ simulation tests will not be accepted.
+
+#### Requirements for Kernels Targeting `src/reference/`
+
+All kernels located in this folder need to have the following tests.
+
+1. Numeric accuracy tests with `nki.baremetal`. The output from the kernel
+must be validated against a CPU reference implementation. See `test_flash_attn_fwd_numerical` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.baremetal` is available at [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.baremetal.html).
+
+2. Performance benchmark tests with `nki.benchmark`. The unit test must have performance checks. At a minimum, put an assertion to verify p99 latency meets a certain threshold. See `test_flash_attn_fwd_perf` in [test_flash_attn_fwd.py](test/unit/test_flash_attn_fwd.py) as an example. Documentation for `nki.benchmark` is available at [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/generated/nki.benchmark.html)
+
+3. End-to-End integration tests that use your kernel in a model. 
+    
+    a. Each test should be in its own separate folder.
+
+    b. Each Test must have a `run.sh` script, that accepts an argument \<path to test_result.json\>. See [run.sh of FlashAttention](test/integration/flash_attention/run.sh) as an example. 
+
+    c. The test scripts must produce benchmark results with the `benchmark` function, located in [LatencyCollector.py](test/integration/perf_utils/LatencyCollector.py). The `benchmark` function will write the latency of your E2E model to the `test_result.json`.
+
+    d. Register your test target in [run_integration.sh](test/integration/run_integration.sh).
+
+
+### Coding Guidelines
+Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions. 
+In addition to PEP-8, we use the following NKI specific style guidelines:
+
+1. **Abbreviations**
+    * Importing NKI modules should use consistent names. For example,
+    ```
+    import neuronxcc.nki as nki
+    import neuronxcc.nki.isa as nisa
+    import neuronxcc.nki.language as nl
+    import neuronxcc.nki.typing as nt
+    import numpy as np
+    ```   
+2. Variable Names
+    * Indexing should specify partition and free dimensions along with the variable they are used for. For example:
+        The index for the partition dimension for tile `a` would be
+        ```
+        i_p_a = nl.arange(128)[:, None]
+        ```
+        while the index for the free dimension for tile `b` would be
+        ```
+        i_f_b = nl.arange(512)[None, :]
+        ```
+    * Name loop variables, indices, and buffers consistently, and specify their intended use in the name.
+
+3. Documentation
+   * New kernels should containing inline docstrings that describe the semantics of the kernel, and provide information on the IO layout. 
+   Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation. 
+
+
+## Finding Contributions to Work on
 Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws-neuron/nki-samples/labels/help%20wanted) issues is a great place to start.
 
 
@@ -51,7 +106,7 @@ For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of
 opensource-codeofconduct@amazon.com with any additional questions or comments.
 
 
-## Security issue notifications
+## Security Issue Notifications
 If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
 
 
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..3b1fad4
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,16 @@
+MIT No Attribution
+
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
diff --git a/LICENSE.txt b/LICENSE.txt
deleted file mode 100644
index e7f39e2..0000000
--- a/LICENSE.txt
+++ /dev/null
@@ -1 +0,0 @@
-TODO: Fill LICENSE after it is finalized
\ No newline at end of file
diff --git a/README.md b/README.md
index 2d97f6b..60602a9 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ At the core of the Neuron SDK is the Neuron Compiler, which takes computation gr
 them into highly optimized machine code. 
 
 [NKI](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki) is a Python-based programming environment designed for the compiler which
-adopts commonly used NumPy andTriton-like syntax along with tile-level semantics. 
+adopts commonly used NumPy and Triton-like syntax along with tile-level semantics. 
 NKI also interoperates with the Neuron Profiler, providing insights into performance bottlenecks and instruction latencies. 
 It offers tensor printing support, standard error messaging, and built-in kernel simulation capabilities for efficient debugging purposes. 
 NKI offers two types of programming interfaces: 
@@ -16,31 +16,25 @@ enabling bare-metal access to the chip for full control.
 
 ![alt "High-level flow of NKI in the Neuron Compiler. NKI emits IR immediately before the backend-IR compilation stage"](doc_assets/high-level-nki-flow.png#center "High-Level NKI Flow")
 
-### nki.language 
-**nki.language** enables precise control over computation and data movement on NeuronCores-- the processing units within AWS Inferentia and Trainium chips. 
-Developers can control data movement between device memory and on-chip memory explicitly using `nl.load()` and `nl.store()` operations. 
-Developers can then perform the desired computations on loaded tensors, such as element-wise operations or tensor contractions, 
-providing crucial performance improvements. Additionally, developers can control how computation is performed on different compute engines inside NeuronCores. 
-nki.language APIs are considered high-level APIs and are designed for "ease of use" for ML practitioners. 
-To achieve the best performance, developers can enlist the nki.isa APIs.
-
-![alt "Diagram of the NeuronCore Architecture. It shows 4 engines: tensor, vector, scalar, and GPSIMD, connected to SBUF memory. The tensor, vector, and scalar engines are also connected to a high-speed PSUM memory bank that supports accumulate on write. Lastly the HBM (DRAM) is connected to both SBUF and PSUM memory banks."](doc_assets/pm-nc.png#scale_50#center "NeuronCore Architecture")
-
-### nki.isa
-
-**nki.isa** provides direct access to chip instructions to offer flexibility and fine-grained control over instruction usage and performance optimizations. 
-Developers can utilize various `nki.isa` instructions using the Tensor, Vector, Scalar, GP-SIMD, and DMA engines. 
-For example, developers can use `nki.isa.nc_matmul()` to compute a matrix multiplication using Tensor Engine. 
-Alternatively, developers can use `nki.isa.activation()` to apply an activation function on every element of the input tile using Scalar Engine.
+## Documentation
+The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/). 
+Documentation for NKI kernels are both inline (docstring) and available on the documentation site's 
+[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html).
 
 ## Repository Structure
 
 ### src
 
 #### reference
-The [reference kernels](src/reference/) are optimized reference kernels. All kernels located in this folder must have all of numeric accuracy tests 
+This folder contains the source code of the `neuronxcc.nki.kernels`, and they are optimized kernels from the Neuron Team serving as samples. 
+
+All kernels located in this folder have numeric accuracy tests 
 and performance benchmarks defined in the [test](test/) directory. We also demonstrate using these kernels end-to-end in our [integration tests](test/integration/).
 
+Note that these kernels are already being deployed as part of the Neuron stack. With flash attention as an example,
+[compiling Llama models with transformers-neuronx](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/transformers-neuronx-developer-guide.html)
+will automatically invoke the `flash_fwd` kernel in [attention.py](src/reference/attention.py). Therefore, replacing the framework operators with these NKI kernels likely won't result in extra performance benefit.
+
 
 #### tutorials
 The [tutorial kernels](src/tutorials/) are for educational purpose and include the kernels that are used in NKI guides. 
@@ -58,65 +52,16 @@ verify the numeric accuracy of the operation, and publish performance results to
 The [integration tests](tests/integration) folder contains integration tests of (selected) kernels. They verify the numeric accuracy of the model’s output, 
 and publish end-to-end performance results into the [integration benchmarks](docs/benchmarks/integration) folder.
 
-## Documentation
-The latest NKI documentation can be found on the AWS Documentation site, [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/). 
-Documentation for NKI kernels are both inline (docstring) and available on the documentation site's 
-[kernel API reference page](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.kernels.html).
+## Maintenance Policy
+NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. NKI API follow the [Neuron SDK Maintenance Policy](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/sdk-policy.html).
 
-## Versioning
-NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. We will also be updating the NKI API as needed 
-to support new Neuron and Neuron Compiler features. While NKI is in beta we may need to make backwards-incompatible changes to incorporate feedback from 
-our users or to support new use-cases of NKI on Neuron devices. Upon releasing NKI as generally available (GA), we will commit to not making backwards 
-incompatible changes to the NKI API for any supported version of the Neuron compiler. 
+## Getting Help
+Have a look at the GitHub issues for this repository where you will find past issues customers have encountered with workarounds and clarifications. 
+If you cannot find a suitable issue for your use-case feel free to [file an issue](https://github.com/aws-neuron/nki-samples/issues/new) to ask for assistance or to suggest improvements. Please read [CONTRIBUTING.md](CONTRIBUTING.md) for detailed information on submitting issues.
 
 ## Contributing
-We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via. 
-GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements.
-
-### Getting Help
-Have a look at the GitHub issues for this repository where you will find past issues customers have encountered with workarounds and clarifications. 
-If you cannot find a suitable issue for your use-case feel free to file an issue asking for assistance or to suggest improvements.
-
-In addition, extensive NKI documentation can be found [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki).
-
-### Testing and Merging
-Running the binaries for a NKI kernel require Neuron devices on an AWS EC2 instance from trn1, trn1n, or inf2 instance families. 
-Details on setting up an instance can be found in [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html).
-
-Before merging, the Neuron team will need to internally test and verify kernels work as expected. If the change is accepted, 
-we will manually merge your changes, and it will be merged here upon the next release. 
-
-If you would like to test your kernel without a requiring a Neuron device, you can use `nki.simulate()` to run your kernel using `NumPy` tensors and types. 
-An example can be found in the [layernorm tutorial test](test/unit/test_tutorials_layernorm.py).
-
-### Coding Guidelines
-Most guidelines are covered by a **PEP-8** check on all newly submitted code, which covers aspects such as code layout and basic Python naming conventions. 
-In addition to PEP-8, we use the following NKI specific style guidelines:
-
-1. **Abbreviations**
-    * Importing NKI modules should use consistent names. For example,
-    ```
-    import neuronxcc.nki as nki
-    import neuronxcc.nki.isa as nisa
-    import neuronxcc.nki.language as nl
-    import neuronxcc.nki.typing as nt
-    import numpy as np
-    ```   
-2. Variable Names
-    * Indexing should specify partition and free dimensions along with the variable they are used for. For example:
-        The index for the partition dimension for tile `a` would be
-        ```
-        i_p_a = nl.arange(128)[:, None]
-        ```
-        while the index for the free dimension for tile `b` would be
-        ```
-        i_f_b = nl.arange(512)[None, :]
-        ```
-    * Name loop variables, indices, and buffers consistently, and specify their intended use in the name.
-
-3. Documentation
-   * New kernels should containing inline docstrings that describe the semantics of the kernel, and provide information on the IO layout. 
-   Upon release, we generate the documentation for our kernels and merge them into the NKI API documentation which will appear in the official AWS NKI documentation. 
+We invite you to join the NKI community! If you'd like to share kernels you create with the community, we welcome your contributions to this repository via
+GitHub pull-requests as well as through filed issues discussing features, bug fixes, new use-cases, and API improvements. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more information
 
 ## Licensing
 This repository is licensed under the terms of the [MIT-0 License](LICENSE.txt)
\ No newline at end of file

From 23835a53cfcb83fcf6bb0eeb0cb98f39d15a8019 Mon Sep 17 00:00:00 2001
From: aws-qieqingy <122939906+aws-qieqingy@users.noreply.github.com>
Date: Sat, 21 Dec 2024 16:27:43 -0500
Subject: [PATCH 3/3] Sync latest kernels and unit tests (#8)

* Add latest kernels and unit tests for NeuronSDK release 2.21

* Move code into src/nki_samples to make unit test executable

* Add Github workflow to run simulation tests on main branch and incoming PRs
---
 .github/workflows/run_simulation_tests.yml    |  32 +
 src/nki_samples/reference/__init__.py         |  10 +
 .../reference/allocated_attention.py          |  12 +-
 .../reference/allocated_fused_linear.py       |   0
 src/{ => nki_samples}/reference/attention.py  | 926 ++++++++++--------
 src/{ => nki_samples}/reference/tutorial.py   |   0
 src/{ => nki_samples}/reference/vision.py     |   0
 .../average_pool2d/average_pool2d_jax.py      |   0
 .../average_pool2d_nki_kernels.py             |   0
 .../average_pool2d/average_pool2d_torch.py    |   0
 .../fused_mamba/mamba_nki_kernels.py          |   0
 .../tutorials/fused_mamba/mamba_torch.py      |   0
 .../layernorm/layernorm_nki_kernel.py         |   0
 .../tutorials/layernorm/layernorm_torch.py    |   0
 .../matrix_multiplication_nki_kernels.py      |   0
 .../matrix_multiplication_torch.py            |   0
 .../tutorials/rmsnorm/rmsnorm_jax.py          |   0
 .../tutorials/rmsnorm/rmsnorm_nki_kernels.py  |   0
 .../tutorials/rmsnorm/rmsnorm_torch.py        |   0
 .../sd_attention/sd_attention_nki_kernels.py  |   0
 .../sd_attention/sd_attention_torch.py        |   0
 .../tensor_addition/tensor_addition_jax.py    |   0
 .../tensor_addition_nki_kernels.py            |   0
 .../tensor_addition/tensor_addition_torch.py  |   0
 .../tutorials/transpose2d/transpose2d_jax.py  |   0
 .../transpose2d/transpose2d_nki_kernels.py    |   0
 .../transpose2d/transpose2d_torch.py          |   0
 src/reference/__init__.py                     |  32 -
 test/unit/README.md                           |   8 +-
 test/unit/__main__.py                         |  14 -
 test/unit/conftest.py                         |  28 +
 test/unit/test_SD_attention_small_head.py     |  14 +-
 .../test_allocated_SD_attention_small_head.py |  15 +-
 test/unit/test_flash_attn_bwd.py              |  21 +-
 test/unit/test_flash_attn_fwd.py              |  24 +-
 test/unit/test_neuron_profile.py              |  86 ++
 test/unit/test_resize_nearest.py              |  16 +-
 test/unit/test_rmsnorm_qkv.py                 |  18 +-
 test/unit/test_select_and_scatter.py          |  16 +-
 39 files changed, 758 insertions(+), 514 deletions(-)
 create mode 100644 .github/workflows/run_simulation_tests.yml
 create mode 100644 src/nki_samples/reference/__init__.py
 rename src/{ => nki_samples}/reference/allocated_attention.py (97%)
 rename src/{ => nki_samples}/reference/allocated_fused_linear.py (100%)
 rename src/{ => nki_samples}/reference/attention.py (54%)
 rename src/{ => nki_samples}/reference/tutorial.py (100%)
 rename src/{ => nki_samples}/reference/vision.py (100%)
 rename src/{ => nki_samples}/tutorials/average_pool2d/average_pool2d_jax.py (100%)
 rename src/{ => nki_samples}/tutorials/average_pool2d/average_pool2d_nki_kernels.py (100%)
 rename src/{ => nki_samples}/tutorials/average_pool2d/average_pool2d_torch.py (100%)
 rename src/{ => nki_samples}/tutorials/fused_mamba/mamba_nki_kernels.py (100%)
 rename src/{ => nki_samples}/tutorials/fused_mamba/mamba_torch.py (100%)
 rename src/{ => nki_samples}/tutorials/layernorm/layernorm_nki_kernel.py (100%)
 rename src/{ => nki_samples}/tutorials/layernorm/layernorm_torch.py (100%)
 rename src/{ => nki_samples}/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py (100%)
 rename src/{ => nki_samples}/tutorials/matrix_multiplication/matrix_multiplication_torch.py (100%)
 rename src/{ => nki_samples}/tutorials/rmsnorm/rmsnorm_jax.py (100%)
 rename src/{ => nki_samples}/tutorials/rmsnorm/rmsnorm_nki_kernels.py (100%)
 rename src/{ => nki_samples}/tutorials/rmsnorm/rmsnorm_torch.py (100%)
 rename src/{ => nki_samples}/tutorials/sd_attention/sd_attention_nki_kernels.py (100%)
 rename src/{ => nki_samples}/tutorials/sd_attention/sd_attention_torch.py (100%)
 rename src/{ => nki_samples}/tutorials/tensor_addition/tensor_addition_jax.py (100%)
 rename src/{ => nki_samples}/tutorials/tensor_addition/tensor_addition_nki_kernels.py (100%)
 rename src/{ => nki_samples}/tutorials/tensor_addition/tensor_addition_torch.py (100%)
 rename src/{ => nki_samples}/tutorials/transpose2d/transpose2d_jax.py (100%)
 rename src/{ => nki_samples}/tutorials/transpose2d/transpose2d_nki_kernels.py (100%)
 rename src/{ => nki_samples}/tutorials/transpose2d/transpose2d_torch.py (100%)
 delete mode 100644 src/reference/__init__.py
 delete mode 100644 test/unit/__main__.py
 create mode 100644 test/unit/conftest.py
 create mode 100644 test/unit/test_neuron_profile.py

diff --git a/.github/workflows/run_simulation_tests.yml b/.github/workflows/run_simulation_tests.yml
new file mode 100644
index 0000000..7ea89ef
--- /dev/null
+++ b/.github/workflows/run_simulation_tests.yml
@@ -0,0 +1,32 @@
+name: Run Python Simulation Tests
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+permissions:
+  contents: read
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v3
+      with:
+        python-version: "3.10"
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
+        python -m pip install wget awscli
+        python -m pip install pytest
+        python -m pip install neuronx-cc==2.*
+    - name: Test with pytest
+      run: |
+        PYTHONPATH=$PYTHONPATH:src/ pytest test/unit/ --simulation-only
\ No newline at end of file
diff --git a/src/nki_samples/reference/__init__.py b/src/nki_samples/reference/__init__.py
new file mode 100644
index 0000000..c9e5d37
--- /dev/null
+++ b/src/nki_samples/reference/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2023, Amazon.com. All Rights Reserved
+
+"""
+Package containing public kernels for Neuron Kernel Interface (NKI).
+
+Kernels here are also available in the `neuronxcc.nki.kernels` namespace, and they 
+are synced with this repository on every Neuron SDK release. 
+
+https://github.com/aws-neuron/nki-samples
+"""
diff --git a/src/reference/allocated_attention.py b/src/nki_samples/reference/allocated_attention.py
similarity index 97%
rename from src/reference/allocated_attention.py
rename to src/nki_samples/reference/allocated_attention.py
index 564412c..94b513f 100644
--- a/src/reference/allocated_attention.py
+++ b/src/nki_samples/reference/allocated_attention.py
@@ -9,13 +9,13 @@
 @nki.jit
 def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref,
                                            use_causal_mask=False,
-                                           mixed_percision=True):
+                                           mixed_precision=True):
   """
   Allocated fused self attention kernel for small head size Stable Diffusion workload.
   
-  Computes (softmax(Q.T@K)V).T. The wired layout is choosen to avoid transpose as
+  Computes (softmax(Q.T@K)V).T. The wired layout is chosen to avoid transpose as
   much as possible to simplify the debug. The kernel uses the direct allocation API,
-  and implements double buffering to achive better performance than automatic allocation.
+  and implements double buffering to achieve better performance than automatic allocation.
   As of NeuronSDK 2.21, it achieves 18% better performance than auto allocated equivalent.
   To see the performance gap, you can use ``force_auto_alloc`` decorator to override
   manual allocation and benchmark the performance difference.
@@ -34,14 +34,14 @@ def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref,
 
   IO tensor dtypes:
    - This kernel assumes all IO tensors have the same dtype
-   - If mixed_percision is True, then all Tensor Engine operation will be performed in
+   - If mixed_precision is True, then all Tensor Engine operation will be performed in
      bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
      will be in the same type as the inputs.
   """
   # Use q_ref dtype as the intermediate tensor dtype
   # Assume all IO tensors have the same dtype
   kernel_dtype = np.float32
-  pe_in_dt = nl.bfloat16 if mixed_percision else kernel_dtype
+  pe_in_dt = nl.bfloat16 if mixed_precision else kernel_dtype
 
   kernel_dtype_itemsize = np.dtype(kernel_dtype).itemsize
   pe_in_dt_itemsize = np.dtype(pe_in_dt).itemsize
@@ -211,7 +211,7 @@ def psum_addr(bank_map, idx, pdim_size, fdim_size):
             on_true_tile=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype)
         else:
           # Copy result to SBUF and find partial maximum for softmax
-          qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.tensor_scalar_reduce(qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], np.add, 1.0,
+          qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.tensor_scalar_reduce(data=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], op0=np.add, operand0=1.0,
               reduce_op=np.max, reduce_res=neg_max_res[i_interleave_grp, ip_max, i_k_seq_tile], dtype=kernel_dtype)
 
       # Find global max from tiles
diff --git a/src/reference/allocated_fused_linear.py b/src/nki_samples/reference/allocated_fused_linear.py
similarity index 100%
rename from src/reference/allocated_fused_linear.py
rename to src/nki_samples/reference/allocated_fused_linear.py
diff --git a/src/reference/attention.py b/src/nki_samples/reference/attention.py
similarity index 54%
rename from src/reference/attention.py
rename to src/nki_samples/reference/attention.py
index 9bf0444..3c456a6 100644
--- a/src/reference/attention.py
+++ b/src/nki_samples/reference/attention.py
@@ -15,6 +15,7 @@
 from functools import reduce as functools_reduce
 from operator import mul as operator_mul
 
+
 def n_elts(shape):
   return functools_reduce(operator_mul, shape, 1)
 
@@ -27,21 +28,59 @@ def linearize(shape, indices):
 def div_ceil(n, d):
   return (n + d - 1) // d
 
+
 @dataclass(frozen=True)
 class FlashConfig:
   """
     Config class for flash attention with default values
   """
   seq_tile_size:int = 2048
+  attn_core_tile_size:int = 256
   training:bool = True
   should_transpose_v:bool = False
+  lse_dtype: str = ""
+
+
+@nki.jit(mode='trace')
+def transpose_p_local(p_local_transposed, p_local, LARGE_TILE_SZ):
+  for i in nl.affine_range(LARGE_TILE_SZ // 512):
+    if nisa.get_nc_version() == nisa.nc_version.gen3:
+      p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.sbuf, dtype=p_local.dtype)
+    else:
+      p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.psum, dtype=np.float32)
+
+    for j in nl.affine_range(512 // 128):
+      j_128_slice = nl.ds(j * 128, 128)
+      i_j_128_slice = nl.ds(i * 512 + j * 128, 128)
+
+      if nisa.get_nc_version() == nisa.nc_version.gen3:
+        p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
+          p_local[:, i_j_128_slice])
+      else:
+        p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
+          p_local[:, i_j_128_slice])
+
+    p_local_transposed[:, nl.ds(i * 512, 512)] = nl.copy(
+      p_local_t_tmp, dtype=p_local_transposed.dtype)
 
-  __annotations__ = {
-    'seq_tile_size': int,
-    'training': bool,
-    'should_transpose_v': bool
-  }
 
+@nki.jit(mode='trace')
+def dropout_p_local(p_local, dropout_p, dropout_p_tensor, seed_tensor,
+                    seed_offset_base, k_r_i, REDUCTION_TILE):
+  B_F_SIZE = 512
+  for k_d_i in nl.sequential_range(REDUCTION_TILE // B_F_SIZE):
+    p_local_f_slice = nl.ds(k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE, B_F_SIZE)
+
+    offset = k_d_i + seed_offset_base
+    offset_seed = nl.add(seed_tensor, offset, dtype=nl.int32)
+    nl.random_seed(seed=offset_seed)
+    softmax_dropout = nl.dropout(p_local[:, p_local_f_slice],
+                                 rate=dropout_p_tensor[:, 0])
+    p_local[:, p_local_f_slice] = nl.multiply(
+      softmax_dropout, 1 / (1 - dropout_p))
+
+
+@nki.jit(mode='trace')
 def _flash_attention_core(q_local_tile, k, v,
                           q_h_per_k_h, seqlen_q, nheads,
                           o_buffer, l_buffer, m_buffer,
@@ -49,169 +88,212 @@ def _flash_attention_core(q_local_tile, k, v,
                           local_k_large_tile_idx,
                           kernel_dtype, acc_type,
                           flash_config: FlashConfig,
-                          olm_buffer_idx=None,
-                          global_k_large_tile_idx=None,
-                          use_causal_mask=False, initialize=False,
+                          use_causal_mask, initialize,
                           B_P_SIZE=128, B_F_SIZE=512, B_D_SIZE=128,
-                          dropout_p=0.0, dropout_p_tensor=None, seed_tensor=None
-                          ):
+                          dropout_p=0.0, dropout_p_tensor=None, seed_tensor=None,
+                          logit_bias_tile=None):
   """
   The flash attention core function to calcualte self attention between a tile of q and a block of K and V.
   The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF already. The block size of K and V
   is defined in the seq_tile_size of the flash_config. The results are stored in the following three buffers
-  o_buffer: (num_large_k_tile, B_P_SIZE, d)
-  l_buffer: (num_large_k_tile, B_P_SIZE, 1)
-  m_buffer: (num_large_k_tile, B_P_SIZE, 1)
+  o_buffer: (B_P_SIZE, d)
+  l_buffer: (B_P_SIZE, 1)
+  m_buffer: (B_P_SIZE, 1)
   """
   LARGE_TILE_SZ = flash_config.seq_tile_size
-  REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
   num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
   seqlen_k = k.shape[-1]
   seq_q_num_tiles = seqlen_q // B_P_SIZE
   seq_k_num_tiles = seqlen_k // B_F_SIZE
 
-  # Indices used by the distributed attention
-  if global_k_large_tile_idx is None:
-    global_k_large_tile_idx = local_k_large_tile_idx
-  if olm_buffer_idx is None:
-    olm_buffer_idx = local_k_large_tile_idx
-
-  i_q_p = nl.arange(B_P_SIZE)[:, None]
-  i_q_f = nl.arange(B_F_SIZE)[None, :]
-  i_d_p = nl.arange(B_D_SIZE)[:, None]
-  i_d_f = nl.arange(B_D_SIZE)[None, :]
-  i_f_128 = nl.arange(B_P_SIZE)[None, :]
-  i_f_k_tiles = nl.arange(num_k_tile_per_large_tile)[None, :]
-
-  # mask are used to only apply computation to the lower half of the matrix,
-  # which reduce the arthimetic intensity by half
-  forward_mask = q_tile_idx * B_P_SIZE >= global_k_large_tile_idx * LARGE_TILE_SZ if use_causal_mask else None
-  # Negation mask is the negation of `forward_mask`, which is used for the
-  # instructions executed on the blocks in the upper triangular section
-  # of the matrix.
-  # These instructions should not be executed when causual mask is disabled.
-  #
-  # For example, the o_buffer still needs to be propagated from o[j-1] to o[j] in
-  # the upper triangular of the matrix.
-  negation_mask = q_tile_idx * B_P_SIZE < global_k_large_tile_idx * LARGE_TILE_SZ if use_causal_mask else None
-
   qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), buffer=nl.sbuf, dtype=acc_type)
   max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), dtype=acc_type)
+
   for k_i in nl.affine_range(num_k_tile_per_large_tile):
-    qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE),
+    k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
+
+    qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
                         dtype=np.float32, buffer=nl.psum)  # (128, 512)
-    multiplication_required_selection = global_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE <= q_tile_idx * B_P_SIZE if use_causal_mask else None
-    qk_psum[i_q_p, i_q_f] += nl.matmul(q_local_tile, k[i_d_p, k_i * B_F_SIZE + i_q_f], transpose_x=True,
-                                       mask=multiplication_required_selection) # (p(128), 512)
+    if use_causal_mask:
+      multiplication_required_selection = q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE
+    else:
+      multiplication_required_selection = True
+
+    if multiplication_required_selection:
+      qk_psum[:, :] = nl.matmul(q_local_tile, k[:, k_i_b_f_slice], transpose_x=True) # (p(128), 512)
+    else:
+      qk_psum[:, :] = 0
 
     if use_causal_mask:
-      left_diagonal_selection = q_tile_idx * B_P_SIZE >= global_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE
-      diagonal_and_right_selection = (q_tile_idx * B_P_SIZE < global_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) & forward_mask
+      left_diagonal_selection = q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE
+      diagonal_and_right_selection = (q_tile_idx * B_P_SIZE < local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE)
+      right_diagonal_selection = ((q_tile_idx + 1) * B_P_SIZE <= local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE)
+      diagonal = ((q_tile_idx * B_P_SIZE < local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) &
+                  ((q_tile_idx + 1) * B_P_SIZE > local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE))
 
+      i_q_p, i_q_f = nl.mgrid[0:B_P_SIZE, 0:B_F_SIZE]
       q_pos = q_tile_idx * B_P_SIZE + i_q_p
-      k_pos = global_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + i_q_f
+      k_pos = local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + i_q_f
       pred = q_pos >= k_pos
-      # For tiles on and to the right of the diagonal, need to do affine_select.
-      # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
-      qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = nisa.affine_select(
-        pred=pred,
-        on_true_tile=qk_psum[i_q_p, i_q_f], on_false_value=-9984.0, dtype=kernel_dtype,
-        mask=diagonal_and_right_selection)
-
-      # For tiles on the left of the diagonal, direct copy, no select required.
-      qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = \
-        nl.copy(qk_psum[i_q_p, i_q_f], dtype=kernel_dtype, mask=left_diagonal_selection)
+
+      qk_select_tmp = nl.ndarray(qk_psum.shape, dtype=qk_psum.dtype, buffer=nl.sbuf)
+
+      if logit_bias_tile is not None:
+        if right_diagonal_selection:
+          qk_select_tmp[...] = qk_psum
+
+          # For tiles to the right of the diagonal, do affine_select.
+          # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+          qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select(
+              pred=pred,
+              on_true_tile=qk_select_tmp, on_false_value=-9984.0, dtype=acc_type)
+
+        # For tiles on the diagonal, add logit bias and need to do affine_select.
+        intermediate = \
+            nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice],
+                   dtype=acc_type, mask=diagonal)
+        qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select(
+            pred=pred,
+            on_true_tile=intermediate, on_false_value=-9984.0, dtype=acc_type,
+            mask=diagonal)
+
+        # For tiles on the left of the diagonal, just add logit bias, no select required.
+        qk_res_buf[:, k_i_b_f_slice] = \
+            nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice],
+                   dtype=acc_type, mask=left_diagonal_selection)
+      else:
+        # For tiles on and to the right of the diagonal, need to do affine_select.
+        # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+        if diagonal_and_right_selection:
+          qk_select_tmp[...] = qk_psum
+
+          qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select(
+            pred=pred,
+            on_true_tile=qk_select_tmp, on_false_value=-9984.0, dtype=acc_type)
+
+        # For tiles on the left of the diagonal, direct copy, no select required.
+        qk_res_buf[:, k_i_b_f_slice] = \
+          nl.copy(qk_psum, dtype=acc_type, mask=left_diagonal_selection)
     else:
-      # Simply send psum result back to sbuf
-      qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = \
-        nl.copy(qk_psum[i_q_p, i_q_f], dtype=kernel_dtype)
+      if logit_bias_tile is not None:
+        # Simply add logit bias which copies back to sbuf at the same time
+        qk_res_buf[:, k_i_b_f_slice] = \
+            nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice], dtype=acc_type)
+      else:
+        # Simply send psum result back to sbuf
+        qk_res_buf[:, k_i_b_f_slice] = nl.copy(qk_psum, dtype=acc_type)
 
     # Calculate max of the current tile
-    max_local[i_q_p, k_i] = nisa.tensor_reduce(np.max, qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f], axis=(1,),
-                                        dtype=acc_type, negate=False, mask=forward_mask)
-
-  max_ = nisa.tensor_reduce(np.max, max_local[i_q_p, i_f_k_tiles], axis=(1, ),
-                    dtype=acc_type, negate=False, mask=forward_mask)
-  if not initialize:
-    m_previous = nl.copy(m_buffer[olm_buffer_idx - 1, i_q_p, 0])
-    m_buffer[olm_buffer_idx, i_q_p, 0] = nl.maximum(m_previous, max_, mask=forward_mask) # (128,1)
-    if use_causal_mask:
-      m_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(m_previous, mask=negation_mask)
+    max_local[:, k_i] = nisa.tensor_reduce(
+      np.max, qk_res_buf[:, k_i_b_f_slice], axis=(1,), dtype=acc_type,
+      negate=False)
 
-    m_current = m_buffer[olm_buffer_idx, i_q_p, 0]
-    # Compute scaling factor
-    alpha = nisa.activation(np.exp, m_previous, bias=-1*m_current, scale=1.0, mask=forward_mask)
-    o_previous = nl.copy(o_buffer[olm_buffer_idx-1, i_q_p, i_d_f], mask=forward_mask)
-    o_previous_scaled = nl.multiply(o_previous, alpha, mask=forward_mask)
-  else:
-    m_buffer[0, i_q_p, 0] = nl.copy(max_)
+  max_ = nisa.tensor_reduce(np.max, max_local[:, :], axis=(1, ),
+                            dtype=acc_type, negate=False)
+
+  o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), dtype=o_buffer.dtype)
+
+  if initialize:
+    m_buffer[:, 0] = nl.copy(max_)
     m_current = max_
+  else:
+    m_previous = nl.copy(m_buffer[:, 0])
+    m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
+
+    m_current = m_buffer[:, 0]
+    # Compute scaling factor
+    alpha = nisa.activation(np.exp, m_current, bias=m_previous, scale=-1.0)
+    o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
 
   p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
-  i_r_f = nl.arange(REDUCTION_TILE)[None,: ]
+  REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
+
   p_partial_sum = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type)
+
   for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
-    # compute exp(qk-max)
-    p_local[i_q_p, k_r_i * REDUCTION_TILE + i_r_f] = \
-      nisa.activation(np.exp,
-                      qk_res_buf[i_q_p, k_r_i * REDUCTION_TILE + i_r_f],
-                      bias=-1 * m_current,
-                      scale=1.0,
-                      dtype=kernel_dtype,
-                      mask=forward_mask)
+    k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
 
     # dropout
     if dropout_p > 0.0:
-      for k_d_i in nl.sequential_range(REDUCTION_TILE // B_F_SIZE):
-        offset = k_d_i + k_r_i * (REDUCTION_TILE // B_F_SIZE) \
-                  + global_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \
-                  + q_tile_idx * seq_k_num_tiles \
-                  + (head_id * q_h_per_k_h + gqa_head_idx) * seq_k_num_tiles * seq_q_num_tiles \
-                  + batch_id * nheads * seq_k_num_tiles * seq_q_num_tiles
-        offset_seed = nl.add(seed_tensor[0, 0], offset, mask=forward_mask)
-        nl.random_seed(seed=offset_seed, mask=forward_mask)
-        softmax_dropout = nl.dropout(p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f],
-                                    rate=dropout_p_tensor[i_q_p, 0],
-                                    mask=forward_mask)
-        p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f] = \
-          nl.multiply(softmax_dropout, 1 / (1 - dropout_p), mask=forward_mask)
-
-    # Compute partial row-tile sum of exp(qk-max))
-    p_partial_sum[i_q_p, k_r_i] = nl.sum(p_local[i_q_p, k_r_i * REDUCTION_TILE + i_r_f], axis=1, dtype=acc_type, mask=forward_mask)
+      # compute exp(qk-max)
+      p_local[:, k_r_i_reduce_slice] = \
+        nisa.activation(np.exp, qk_res_buf[:, k_r_i_reduce_slice],
+                        bias=-1 * m_current, scale=1.0,
+                        dtype=kernel_dtype)
+
+      seed_offset_base = k_r_i * (REDUCTION_TILE // B_F_SIZE) \
+                         + local_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \
+                         + q_tile_idx * seq_k_num_tiles \
+                         + (head_id * q_h_per_k_h + gqa_head_idx) * seq_k_num_tiles * seq_q_num_tiles \
+                         + batch_id * nheads * seq_k_num_tiles * seq_q_num_tiles
+
+      dropout_p_local(p_local=p_local, dropout_p=dropout_p,
+                      dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_tensor,
+                      seed_offset_base=seed_offset_base, k_r_i=k_r_i,
+                      REDUCTION_TILE=REDUCTION_TILE)
+
+      # Compute partial row-tile sum of exp(qk-max))
+      # FIXME: Use activation accumulate and accumulate over k_r_i loop?
+      p_partial_sum[:, k_r_i] = nl.sum(p_local[:, k_r_i_reduce_slice],
+                                       axis=1, dtype=acc_type)
+    else:
+      # compute exp(qk-max)
+      # Compute partial row-tile sum of exp(qk-max))
+      # FIXME: Use activation accumulate to accumulate over k_r_i loop?
+      p_local[:, k_r_i_reduce_slice] = \
+        nisa.activation_reduce(np.exp, qk_res_buf[:, k_r_i_reduce_slice],
+                               bias=-1 * m_current, scale=1.0,
+                               reduce_op=nl.add, reduce_res=p_partial_sum[:, k_r_i],
+                               dtype=kernel_dtype)
+
+  ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
 
   p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
-  for i_p_t in nl.affine_range(LARGE_TILE_SZ // 512):
-    p_local_t_tmp = nl.ndarray((par_dim(B_P_SIZE), 512), buffer=nl.psum, dtype=np.float32)
-    for i_p_t_local in nl.affine_range(512//128):
-      p_local_t_tmp[i_q_p, i_p_t_local*128 + i_f_128] = nisa.nc_transpose(p_local[i_q_p, i_p_t*512+i_p_t_local * B_P_SIZE + i_f_128], mask=forward_mask)
-    i_f_512 = nl.arange(512)[None, :]
-    p_local_transposed[i_q_p, i_p_t * 512 + i_f_512 ] = nl.copy(p_local_t_tmp[i_q_p, i_f_512], dtype=kernel_dtype, mask=forward_mask)
-
-  ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)
-  pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32, buffer=nl.psum)
+  transpose_p_local(p_local_transposed=p_local_transposed, p_local=p_local,
+                    LARGE_TILE_SZ=LARGE_TILE_SZ)
+
+  pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32,
+                     buffer=nl.psum, lazy_initialization=True)
   for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
-    pv_psum[i_q_p, i_d_f] += nl.matmul(p_local_transposed[i_q_p, k_i * B_P_SIZE + i_f_128],
-                                       v[k_i, i_q_p, i_d_f],
-                                       transpose_x=True,
-                                       mask=forward_mask) # (128, 128) (p(Br), d)
+    pv_psum[:, :] += nl.matmul(p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
+                               v[k_i, :, :], transpose_x=True) # (128, 128) (p(Br), d)
 
   if initialize:
-    o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.copy(pv_psum[i_q_p, i_d_f])
-    l_buffer[olm_buffer_idx, i_q_p, 0] = nl.add(nl.log(ps), max_)
+    o_buffer[:, :] = nl.copy(pv_psum[:, :])
+    l_buffer[:, 0] = nl.add(nl.log(ps), max_)
   else:
-    if use_causal_mask:
-      o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.copy(o_buffer[olm_buffer_idx-1, i_q_p, i_d_f], mask=negation_mask)
-    o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask)
+    o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
+
+    exp = nisa.activation(nl.exp, m_current, bias=l_buffer[:, 0], scale=-1.0)
+    l_buffer[:, 0] = nl.add(m_current, nisa.activation(nl.log, exp, bias=ps))
+
+
+@nki.jit(mode='trace')
+def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
+  LARGE_TILE_SZ = config.seq_tile_size
+  B_P_SIZE = 128
+
+  if not config.should_transpose_v:
+    cur_v_tile[v_i, :, :] = nl.load(
+      v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :],
+      dtype=cur_v_tile.dtype)
+    return
+
+  if nisa.get_nc_version() == nisa.nc_version.gen3:
+    cur_v_tile_transposed = nisa.dma_transpose(
+      v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)])
+    cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed,
+                                             dtype=cur_v_tile.dtype)
+    return
+
+  cur_v_tile[v_i, :, :] = nl.load_transpose2d(
+    v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)],
+    dtype=cur_v_tile.dtype)
 
-    l_prev = l_buffer[olm_buffer_idx-1, i_q_p, 0]
-    l_exp = nl.add(nl.exp(nl.subtract(l_prev, m_current, mask=forward_mask), mask=forward_mask), ps, mask=forward_mask)
-    l_buffer[olm_buffer_idx, i_q_p, 0] = nl.add(m_current, nl.log(l_exp, mask=forward_mask), mask=forward_mask)
-    if use_causal_mask:
-      l_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(l_buffer[olm_buffer_idx-1, i_q_p, 0], mask=negation_mask)
 
 
 @nki.jit
-def flash_fwd(q, k, v, seed,
+def flash_fwd(q, k, v, seed, logit_bias=None,
               softmax_scale=None,
               use_causal_mask=True,
               mixed_precision=True,
@@ -224,27 +306,35 @@ def flash_fwd(q, k, v, seed,
     - k: shape   (bs, nk_heads, d, seq_k)
     - v: shape   (bs, nv_heads, d, seq_v) if config.should_transpose_v  else (bs, nv_heads, seq_v, d)
     - seed: shape (1,)
+    - logit_bias: shape (bs, n_heads, seq_q, seq_k)
     - o: shape (bs, n_heads, seq_q, d)
     - lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None
     - This kernel requires seq_k == seq_v
 
   IO tensor dtypes:
     - This kernel assumes all IO tensors have the same dtype
-    - If mixed_percision is True, then all Tensor Engine operation will be performed in
+    - If mixed_precision is True, then all Tensor Engine operation will be performed in
       bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
       will be in the same type as the inputs.
 
   Compile-time Constants:
     - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
-    - mixed_precision: flag to set non-matmul ops in fp32 precision, defualt is set to `true`, if false, we use same precision as input types
+    - mixed_precision: flag to set non-matmul ops in fp32 precision, default is set to `true`, if false, we use same precision as input types
     - causal_mask: flag to set causal masking
-    - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig` with Performance config parameters for flash attention with default values
+    - config: Instance of :class:`nki.kernels.attention.FlashConfig` with Performance config parameters for flash attention with default values
         seq_tile_size: `default=2048`, size of the kv tile size for attention computation reduction
         training: bool to indicate training vs inference `default=True`
 
   Performance Notes:
-    For better performance, the kernel is tiled to be of size `LARGE_TILE_SZ`, and Flash attention math techniques are applied in unit
-    of `LARGE_TILE_SZ`. Seqlen that is not divisible by `LARGE_TILE_SZ` is not supported at the moment.
+    For better performance, the kernel is tiled to be of size `config.seq_tile_size`, and Flash attention math techniques are applied in unit
+    of `config.seq_tile_size`. Seqlen that is not divisible by `config.seq_tile_size` is not supported at the moment.
+
+    For large seqlen, `o_buffer` will overflow the statebuf. the kernel is tile `o_buffer` based on the value of `config.attn_core_tile_size`.
+    This is a tradeoff between memory usage and performance. The default value of `config.attn_core_tile_size` is 256, which means the `o_buffer`
+    will roughly take half of the statebuf. The computes are also tiled accordingly. DMA will be rematerialized
+    `seqlen_q // B_P_SIZE // attn_core_tile_size times`.
+
+
 
   GQA support Notes:
     the spmd kernel for launching kernel should be on kv_heads instead of nheads
@@ -273,26 +363,27 @@ def flash_fwd(q, k, v, seed,
 
   o = nl.ndarray((b, h, seqlen_q, d), dtype=q.dtype, buffer=nl.shared_hbm)
   if config.training:
+    if config.lse_dtype:
+      lse_dtype = getattr(nl, config.lse_dtype)
+    else:
+      lse_dtype = acc_type
     lse = nl.ndarray((b, h, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax),
-                     dtype=acc_type, buffer=nl.shared_hbm)
+                     dtype=lse_dtype, buffer=nl.shared_hbm)
   else:
     lse = None
 
-  i_q_p = nl.arange(B_P_SIZE)[:,None]
-  i_0_f = nl.arange(1)[None, :]
-
+  assert nl.program_ndim() == 2,\
+    f'Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!'
   batch_id = nl.program_id(axis=0)
-
-  head_dims = list(range(1, nl.program_ndim()))
-  head_dims_shape = list(nl.num_programs(i) for i in head_dims)
-  head_dims_idx = list(nl.program_id(i) for i in head_dims)
-  head_id = linearize(head_dims_shape, head_dims_idx)
+  head_id = nl.program_id(axis=1)
 
   softmax_scale = softmax_scale or (1.0 / (d ** 0.5))
 
   n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
 
   LARGE_TILE_SZ = config.seq_tile_size
+  attn_core_tile_size = config.attn_core_tile_size
+
   # FIXME: Add masking for different seqlen values.
   assert config.seq_tile_size >= 512, f" seq tile_size {config.seq_tile_size} cannot be less than 512"
   assert seqlen_k % LARGE_TILE_SZ == 0, f"Need seqlen_k to be divisible by {LARGE_TILE_SZ} but got {seqlen_k}"
@@ -316,131 +407,107 @@ def flash_fwd(q, k, v, seed,
     dropout_p_tensor = None
     seed_local = None
 
-  for i_q_h in nl.affine_range(q_h_per_k_h):
+  if logit_bias is not None:
+    b_logit_bias, h_logit_bias, _, _ = logit_bias.shape
+    assert b_logit_bias == 1 and h_logit_bias == 1, "only support broadcasting logit_bias with batch 1, n_heads 1"
 
+  n_remat = div_ceil(n_tile_q, attn_core_tile_size)
+  attn_core_tile_size = min(n_tile_q, attn_core_tile_size)
+
+  for i_q_h in nl.affine_range(q_h_per_k_h):
     # =============== Global Flash Attention accumulators ====================== #
-    o_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), d), 0.0, dtype=acc_type, buffer=nl.sbuf)
-    l_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), 1), 0.0, dtype=acc_type, buffer=nl.sbuf)
-    m_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), 1), 0.0, dtype=acc_type)
+    l_buffer = nl.zeros((par_dim(B_P_SIZE), n_tile_q), dtype=acc_type,
+                        buffer=nl.sbuf, lazy_initialization=True)
     # =============== Global Flash Attention accumulators END ================== #
 
-    j = 0
-    cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
-    cur_v_tile = nl.ndarray((LARGE_TILE_SZ//B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype)
-    load_tile_size = B_P_SIZE
-    for k_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-      load_p = nl.arange(B_D_SIZE)[:, None]
-      load_f = nl.arange(load_tile_size)[None, :]
-      cur_k_tile[load_p, load_tile_size*k_i+load_f] = nl.load(
-        k[batch_id, head_id, load_p, load_tile_size*k_i+load_f]
-      )
-    if config.should_transpose_v:
-      for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-        load_p = nl.arange(B_D_SIZE)[:, None]
-        load_f = nl.arange(B_P_SIZE)[None, :]
-
-        loaded = nl.load(v[batch_id, head_id, load_p, B_P_SIZE*v_i+load_f], dtype=kernel_dtype)
-        store_p = nl.arange(B_P_SIZE)[:, None]
-        store_f = nl.arange(B_D_SIZE)[None, :]
-        cur_v_tile[v_i, store_p, store_f] = nisa.nc_transpose(loaded)
-    else:
-      for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-        load_p = nl.arange(B_P_SIZE)[:, None]
-        load_f = nl.arange(B_D_SIZE)[None, :]
-
-        cur_v_tile[v_i, load_p, load_f] = nl.load(v[batch_id, head_id, B_P_SIZE*v_i+load_p, load_f], dtype=kernel_dtype)
-
-    for i in nl.affine_range(n_tile_q):
-      i_f_128 = nl.arange(B_P_SIZE)[None, :]
-      i_f_d = nl.arange(B_D_SIZE)[None, :]
-      i_p_d = nl.arange(B_D_SIZE)[:,None]
-      q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype)
-      q_tile[i_p_d, i_f_128] = nl.load(q[batch_id,
-                                        head_id * q_h_per_k_h + i_q_h, i_p_d,
-                                        i * B_P_SIZE + i_f_128],
-                                      dtype=kernel_dtype) * softmax_scale # load (d, 128) tile in SBUF
-      # handle first tile and compute max and lse explicitly by passing initialize=True
-      _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile,
-                            q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h,
-                            o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i],
-                            batch_id=batch_id, head_id=head_id,
-                            gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=0,
-                            kernel_dtype=kernel_dtype, acc_type=acc_type,
-                            flash_config=config, use_causal_mask=use_causal_mask,
-                            initialize=True,
-                            B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE,
-                            dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_local)
-
-    for j in nl.sequential_range(1, num_large_k_tile):
-      cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
-      cur_v_tile = nl.ndarray((LARGE_TILE_SZ//B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype)
-      load_tile_size = B_P_SIZE
-      for k_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-        load_p = nl.arange(B_D_SIZE)[:, None]
-        load_f = nl.arange(load_tile_size)[None, :]
-        cur_k_tile[load_p, load_tile_size*k_i+load_f] = nl.load(
-          k[batch_id, head_id, load_p, j*LARGE_TILE_SZ+load_tile_size*k_i+load_f]
-        )
-      if config.should_transpose_v:
-        for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-          load_p = nl.arange(B_D_SIZE)[:, None]
-          load_f = nl.arange(B_P_SIZE)[None, :]
+    for i0 in nl.sequential_range(n_remat):
+      # =============== Global Flash Attention accumulators ====================== #
+      o_buffer = nl.zeros((attn_core_tile_size, par_dim(B_P_SIZE), d), dtype=acc_type,
+                          buffer=nl.sbuf, lazy_initialization=True)
+      m_buffer = nl.zeros((attn_core_tile_size, par_dim(B_P_SIZE), 1), dtype=acc_type,
+                          buffer=nl.sbuf, lazy_initialization=True)
+      # =============== Global Flash Attention accumulators END ================== #
 
-          loaded = nl.load(v[batch_id, head_id, load_p, j*LARGE_TILE_SZ+B_P_SIZE*v_i+load_f], dtype=kernel_dtype)
-          store_p = nl.arange(B_P_SIZE)[:, None]
-          store_f = nl.arange(B_D_SIZE)[None, :]
-          cur_v_tile[v_i, store_p, store_f] = nisa.nc_transpose(loaded)
-      else:
+      for j in nl.sequential_range(0, num_large_k_tile):
+        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
+        cur_v_tile = nl.ndarray((LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype)
+
+        cur_k_tile[:, :] = nl.load(k[batch_id, head_id, :, nl.ds(j*LARGE_TILE_SZ, LARGE_TILE_SZ)])
+
+        load_tile_size = B_P_SIZE
+
+        v_hbm_tile = v[batch_id, head_id]
         for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-          load_p = nl.arange(B_P_SIZE)[:, None]
-          load_f = nl.arange(B_D_SIZE)[None, :]
-
-          cur_v_tile[v_i, load_p, load_f] = nl.load(v[batch_id, head_id, j*LARGE_TILE_SZ+B_P_SIZE*v_i+load_p, load_f], dtype=kernel_dtype)
-
-      for i in nl.affine_range(n_tile_q):
-        i_f_128 = nl.arange(B_P_SIZE)[None, :]
-        i_f_d = nl.arange(B_D_SIZE)[None, :]
-        i_p_d = nl.arange(B_D_SIZE)[:,None]
-        q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype)
-        q_tile[i_p_d, i_f_128] = nl.load(q[batch_id,
-                head_id * q_h_per_k_h + i_q_h, i_p_d,
-                i * B_P_SIZE + i_f_128],
-              dtype=kernel_dtype) * softmax_scale # load (d, 128) tile in SBUF
-        _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile,
-                              q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h,
-                              o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i],
-                              batch_id=batch_id, head_id=head_id,
-                              gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j,
-                              kernel_dtype=kernel_dtype, acc_type=acc_type,
-                              flash_config=config, use_causal_mask=use_causal_mask,
-                              initialize=False,
-                              B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE,
-                              dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_local)
-
-    # -------- write output to buffer on HBM ------------ #
-    for i in nl.affine_range(n_tile_q):
-      out = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype)
-      out[i_q_p, i_f_d] = nl.multiply(o_buffer[i, num_large_k_tile - 1, i_q_p, i_f_d],
-                                      nl.exp(m_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f] - l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f]),
-                                      dtype=kernel_dtype)
-
-      nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, i*B_P_SIZE + i_q_p, i_f_d], out[i_q_p, i_f_d])
-      if not inference:
-        lse_local = nl.zeros((par_dim(B_P_SIZE), 1), dtype=acc_type)
-        lse_local[i_q_p, i_0_f] = nl.copy(l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f], dtype=acc_type)
-        nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, i_q_p, i + i_0_f], lse_local[i_q_p, i_0_f])
+          load_v_tile(v_hbm_tile=v_hbm_tile, cur_v_tile=cur_v_tile, j=j, v_i=v_i,
+                      config=config)
+
+        for i1 in nl.affine_range(attn_core_tile_size):
+          i = i0 * attn_core_tile_size + i1
+          # mask are used to only apply computation to the lower half of the matrix,
+          # which reduce the arthimetic intensity by half.
+          # forward_mask imply initialize, i.e. if forward_mask is false, initialize will
+          # be false as well
+          if use_causal_mask:
+            forward_mask = i * B_P_SIZE >= j * LARGE_TILE_SZ
+          else:
+            forward_mask = True
+
+          if (i < n_tile_q) & forward_mask:
+            q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype)
+            q_hbm_tile = q[batch_id, head_id * q_h_per_k_h + i_q_h]
+            q_sbuf_tile = nl.load(q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
+                                  dtype=kernel_dtype) # load (d, 128) tile in SBUF
+            q_tile[:, :] = q_sbuf_tile * softmax_scale
+
+            logit_bias_tile = None
+            if logit_bias is not None:
+              logit_bias_tile = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
+              logit_bias_tile[:, :] = nl.load(
+                logit_bias[0, 0, nl.ds(i * B_P_SIZE, B_P_SIZE),
+                           nl.ds(j * LARGE_TILE_SZ, LARGE_TILE_SZ)])
+
+            _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile,
+                                  q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h,
+                                  o_buffer=o_buffer[i1], l_buffer=l_buffer[:, i], m_buffer=m_buffer[i1],
+                                  batch_id=batch_id, head_id=head_id,
+                                  gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j,
+                                  kernel_dtype=kernel_dtype, acc_type=acc_type,
+                                  flash_config=config, use_causal_mask=use_causal_mask,
+                                  initialize=j == 0,
+                                  B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE,
+                                  dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor,
+                                  seed_tensor=seed_local, logit_bias_tile=logit_bias_tile)
+
+      # -------- write output to buffer on HBM ------------ #
+      for i1 in nl.affine_range(attn_core_tile_size):
+        i = i0 * attn_core_tile_size + i1
+
+        if i < n_tile_q:
+          exp = nisa.activation(np.exp, l_buffer[:, i], bias=m_buffer[i1, :, :],
+                                scale=-1.0)
+          out = nl.multiply(o_buffer[i1, :, :], exp,
+                            dtype=kernel_dtype)
+
+          nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h,
+                     nl.ds(i*B_P_SIZE, B_P_SIZE), :], out)
+
+    if not inference:
+      nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], l_buffer[:, :])
 
   if config.training:
     return o, lse
 
   return o
 
+
+
 @nki.jit
 def flash_attn_bwd(
   q_ref, k_ref, v_ref, o_ref,
   dy_ref,
   lse_ref,
   seed_ref,
+  logit_bias_ref=None,
   use_causal_mask=False,
   mixed_precision=False,
   dropout_p=0.0,
@@ -457,6 +524,7 @@ def flash_attn_bwd(
    - dy_ref: shape (bs, nheads, head_size, seq)
    - lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax)
    - seed_ref: shape (1,)
+   - logit_bias_ref: shape (bs, n_heads, seq_q, seq_k)
    - out_dq_ref: shape (bs, nheads, head_size, seq)
    - out_dk_ref: shape (bs, nheads, head_size, seq)
    - out_dv_ref: shape (bs, nheads, head_size, seq)
@@ -464,11 +532,11 @@ def flash_attn_bwd(
   Detailed steps:
     1. D = rowsum(dO ◦ O) (pointwise multiply)
 
-    2. Recompute (softmax(Q^T@K))
+    2. Recompute (softmax(Q^T@K + logic_bias))
 
       2.1 Q^T@K
       2.2 Scale the QK score
-      2.3 Apply causal mask
+      2.3 Apply causal mask and add logit_bias
       2.4 softmax
 
     3. Compute the gradients of y = score @ V with respect to the loss
@@ -487,7 +555,6 @@ def flash_attn_bwd(
   mixed_dtype = np.dtype(np.float32) if mixed_precision else kernel_dtype
 
   assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype
-  assert lse_ref.dtype == mixed_dtype
 
   # Shape checking
   bs, nheads, d_head, seqlen_q = q_ref.shape
@@ -520,16 +587,18 @@ def flash_attn_bwd(
   # Softmax scaling factor, multiplied onto Q
   softmax_scale = softmax_scale or 1.0 / float(d_head ** 0.5)
 
+  assert nl.program_ndim() == 2,\
+    f'Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!'
   # Different batch samples/attention heads have independent attention
   batch_id = nl.program_id(axis=0)
+  head_id = nl.program_id(axis=1)
 
-  head_dims = list(range(1, nl.program_ndim()))
-  head_dims_shape = list(nl.num_programs(i) for i in head_dims)
-  head_dims_idx = list(nl.program_id(i) for i in head_dims)
-  head_id = linearize(head_dims_shape, head_dims_idx)
+  assert nl.num_programs(1) == nheads, \
+    f"The grid shape mismatch, got {nl.num_programs(1)} but should be {nheads}"
 
-  assert n_elts(head_dims_shape) == nheads, \
-    f"The grid shape mismatch, got {n_elts(head_dims_shape)} but should be {nheads}"
+  if logit_bias_ref is not None:
+    b_logit_bias, h_logit_bias, _, _ = logit_bias_ref.shape
+    assert b_logit_bias == 1 and h_logit_bias == 1, "Only support broadcasting logit_bias with batch 1, n_heads 1"
 
   q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128
   d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128)
@@ -545,45 +614,19 @@ def flash_attn_bwd(
   ##############################################################
   # Step 2.4 Prefetch exp bias for softmax
   ##############################################################
-  softmax_exp_bias = nl.zeros((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype)
-  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
-    ip_qk = nl.arange(q_seq_tile_size)[:, None]
-    lse_local = nl.load(
-      lse_ref[batch_id, head_id, ip_qk, i_q_seq_tile],
-      dtype=mixed_dtype)
-    softmax_exp_bias[i_q_seq_tile, ip_qk, 0] = lse_local * -1.0
+  softmax_exp_bias = nl.zeros((par_dim(q_seq_tile_size), q_seq_n_tiles), dtype=mixed_dtype)
+  lse_local = nl.load(lse_ref[batch_id, head_id, :, :], dtype=mixed_dtype)
+  softmax_exp_bias[:, :] = lse_local * -1.0
 
   ##############################################################
   # Step 1 Compute rowsum(dO ◦ O)
   ##############################################################
   dy_o_sum = nl.ndarray((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype)
-  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
-    ip_reduce = nl.arange(q_seq_tile_size)[:, None]
-    dy_o_partial = nl.zeros((par_dim(q_seq_tile_size), d_head_n_tiles), dtype=mixed_dtype)
-    for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-      ip_load = nl.arange(d_head_tile_size)[:, None]
-      if_q = nl.arange(q_seq_tile_size)[None, :]
-      dy_local = nl.load_transpose2d(
-        dy_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_load, i_q_seq_tile * q_seq_tile_size + if_q],
-        dtype=mixed_dtype)
-      o_local = nl.load_transpose2d(
-        o_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_load, i_q_seq_tile * q_seq_tile_size + if_q],
-        dtype=mixed_dtype
-      )
-
-      dy_o_partial[ip_reduce, i_d_head_tile] = nisa.tensor_reduce(
-        np.add, data=dy_local*o_local, axis=(1,), dtype=mixed_dtype
-      )
-
-    dy_o_sum[i_q_seq_tile, ip_reduce, 0] = nisa.tensor_reduce(
-      np.add, data=dy_o_partial[ip_reduce, nl.arange(d_head_n_tiles)[None, :]],
-      axis=(1,), dtype=mixed_dtype
-    )
-
-  # Indices for prefetch
-  ip_qk = nl.arange(d_head_tile_size)[:, None]
-  if_q = nl.arange(q_seq_tile_size)[None, :]
-  if_k = nl.arange(k_seq_tile_size)[None, :]
+  compute_rowsum(dy_o_sum=dy_o_sum,
+                 dy_ref_hbm_tile=dy_ref[batch_id, head_id],
+                 o_ref_hbm_tile=o_ref[batch_id, head_id],
+                 d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size,
+                 q_seq_n_tiles=q_seq_n_tiles, q_seq_tile_size=q_seq_tile_size)
 
   if dropout_p > 0.0:
     seed_local = nl.load(seed_ref[0])
@@ -603,28 +646,25 @@ def flash_attn_bwd(
   _range = nl.sequential_range if dropout_p > 0.0 else nl.affine_range
 
   for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
-    # Prefetch V, K
-    v_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=kernel_dtype)
-    k_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=kernel_dtype)
-    transposed_k_local = nl.zeros((k_seq_fwd_bwd_tile_multipler, d_head_n_tiles, par_dim(k_seq_tile_size_backward), d_head_tile_size), dtype=kernel_dtype)
-    for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-      k_local[i_d_head_tile, ip_qk, if_k] = nl.load(
-        k_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_k_seq_tile * k_seq_tile_size + if_k],
-        dtype=kernel_dtype)
-      v_local[i_d_head_tile, ip_qk, if_k] = nl.load(
-        v_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_k_seq_tile * k_seq_tile_size + if_k],
-        dtype=kernel_dtype)
-      ##############################################################
-      # Prefetch k transpose for the backward too
-      ##############################################################
-      if_k_backward = nl.arange(k_seq_tile_size_backward)[None, :]
-      ip_k_backward = nl.arange(k_seq_tile_size_backward)[:, None]
-      if_d_head = nl.arange(d_head_tile_size)[None, :]
-      for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler):
-        transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, ip_k_backward, if_d_head] = \
-          nisa.nc_transpose(k_local[i_d_head_tile, ip_qk,
-                                    i_k_seq_tile_backward * k_seq_tile_size_backward + if_k_backward])
+    i_k_seq_dslice = nl.ds(i_k_seq_tile * k_seq_tile_size, k_seq_tile_size)
 
+    # Prefetch V, K
+    v_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
+                       dtype=kernel_dtype)
+    k_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
+                       dtype=kernel_dtype)
+    transposed_k_local = nl.zeros((k_seq_fwd_bwd_tile_multipler, d_head_n_tiles,
+                                   par_dim(k_seq_tile_size_backward), d_head_tile_size),
+                                  dtype=kernel_dtype)
+
+    load_kv(k_ref_hbm_tile=k_ref[batch_id, head_id],
+            v_ref_hbm_tile=v_ref[batch_id, head_id],
+            k_local=k_local, transposed_k_local=transposed_k_local, v_local=v_local,
+            d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size,
+            i_k_seq_tile=i_k_seq_tile, k_seq_tile_size=k_seq_tile_size,
+            k_seq_tile_size_backward=k_seq_tile_size_backward)
+
+    # FIXME: Pass sbuf instead, we will have psum spilling in the current implementation
     dv_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
                         dtype=np.float32, buffer=nl.psum)
     dk_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
@@ -633,17 +673,20 @@ def flash_attn_bwd(
       # Prefetch dy, Q
       dy_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype)
       q_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype)
-      for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-        ip_qk = nl.arange(d_head_tile_size)[:, None]
-        if_q = nl.arange(q_seq_tile_size)[None, :]
 
-        dy_local[i_d_head_tile, ip_qk, if_q] = nl.load(
-          dy_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_q_seq_tile * q_seq_tile_size + if_q],
-          dtype=kernel_dtype)
+      load_dy_q(dy_ref_hbm_tile = dy_ref[batch_id, head_id],
+                q_ref_hbm_tile = q_ref[batch_id, head_id],
+                dy_local=dy_local, q_local=q_local, d_head_n_tiles=d_head_n_tiles,
+                d_head_tile_size=d_head_tile_size, i_q_seq_tile=i_q_seq_tile,
+                q_seq_tile_size=q_seq_tile_size, softmax_scale=softmax_scale)
 
-        q_local[i_d_head_tile, ip_qk, if_q] = nl.load(
-          q_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_q_seq_tile * q_seq_tile_size + if_q],
-          dtype=kernel_dtype) * softmax_scale
+      logit_bias_tile = None
+      if logit_bias_ref is not None:
+        i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size)
+        logit_bias_tile = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size),
+                                     buffer=nl.sbuf, dtype=kernel_dtype)
+        logit_bias_tile[:, :] = nl.load(
+          logit_bias_ref[0, 0, i_q_seq_dslice, i_k_seq_dslice])
 
       _flash_attn_bwd_core(
         q_local=q_local, k_local=k_local, transposed_k_local=transposed_k_local,
@@ -656,43 +699,102 @@ def flash_attn_bwd(
         kernel_dtype=kernel_dtype, mixed_dtype=mixed_dtype,
         softmax_scale=softmax_scale,
         seed_local=seed_local, dropout_p=dropout_p, dropout_p_local=dropout_p_local,
+        logit_bias_tile=logit_bias_tile
       )
 
     # Write dK, dV
-    for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-      ip_dkv = nl.arange(d_head_tile_size)[:, None]
-      if_dkv = nl.arange(k_seq_tile_size)[None, :]
-
-      nl.store(
-        out_dv_ref[batch_id, head_id,
-                   i_d_head_tile * d_head_tile_size + ip_dkv,
-                   i_k_seq_tile * k_seq_tile_size + if_dkv],
-        value=dv_psum[i_d_head_tile, ip_dkv, if_dkv],
-      )
-
-      nl.store(
-        out_dk_ref[batch_id, head_id,
-                    i_d_head_tile * d_head_tile_size + ip_dkv,
-                    i_k_seq_tile * k_seq_tile_size + if_dkv],
-        value=dk_psum[i_d_head_tile, ip_dkv, if_dkv],
-      )
+    store_dk_dv(out_dk_ref_hbm_tile=out_dk_ref[batch_id, head_id],
+                out_dv_ref_hbm_tile=out_dv_ref[batch_id, head_id],
+                local_dk=dk_psum, local_dv=dv_psum, i_k_seq_dslice=i_k_seq_dslice,
+                d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size)
 
   # Write dQ
   for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
     for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-      ip_dq = nl.arange(d_head_tile_size)[:, None]
-      if_dq = nl.arange(q_seq_tile_size)[None, :]
-
+      i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size)
+      i_d_head_dslice = nl.ds(i_d_head_tile * d_head_tile_size, d_head_tile_size)
       nl.store(
-        out_dq_ref[batch_id, head_id,
-                   i_d_head_tile * d_head_tile_size + ip_dq,
-                   i_q_seq_tile * q_seq_tile_size + if_dq],
-        value=dq_local_reduced[i_q_seq_tile, i_d_head_tile, ip_dq, if_dq],
+        out_dq_ref[batch_id, head_id, i_d_head_dslice, i_q_seq_dslice],
+        value=dq_local_reduced[i_q_seq_tile, i_d_head_tile, :, :],
       )
 
   return out_dq_ref, out_dk_ref, out_dv_ref
 
 
+@nki.jit(mode='trace')
+def load_dy_q(dy_ref_hbm_tile, q_ref_hbm_tile, dy_local, q_local, d_head_n_tiles, d_head_tile_size, i_q_seq_tile,
+              q_seq_tile_size, softmax_scale):
+  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
+    i_d_head_dslice = nl.ds(i_d_head_tile * d_head_tile_size, d_head_tile_size)
+    i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size)
+
+    dy_local[i_d_head_tile, :, :] = nl.load(
+      dy_ref_hbm_tile[i_d_head_dslice, i_q_seq_dslice],
+      dtype=dy_local.dtype)
+
+    q_local[i_d_head_tile, :, :] = nl.load(
+      q_ref_hbm_tile[i_d_head_dslice, i_q_seq_dslice],
+      dtype=q_local.dtype) * softmax_scale
+
+
+@nki.jit(mode='trace')
+def store_dk_dv(out_dk_ref_hbm_tile, out_dv_ref_hbm_tile, local_dk, local_dv,
+                d_head_n_tiles, d_head_tile_size, i_k_seq_dslice):
+  for i in nl.affine_range(d_head_n_tiles):
+    i_d_head_dslice = nl.ds(i * d_head_tile_size, d_head_tile_size)
+
+    nl.store(out_dv_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice],
+             value=local_dv[i, :, :])
+
+    nl.store(out_dk_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice],
+             value=local_dk[i, :, :])
+
+
+@nki.jit(mode='trace')
+def load_kv(k_ref_hbm_tile, v_ref_hbm_tile, k_local, transposed_k_local, v_local,
+            d_head_n_tiles, d_head_tile_size, i_k_seq_tile, k_seq_tile_size,
+            k_seq_tile_size_backward):
+  k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward
+
+  for i in nl.affine_range(d_head_n_tiles):
+    i_d_head_dslice = nl.ds(i * d_head_tile_size, d_head_tile_size)
+    i_k_seq_dslice = nl.ds(i_k_seq_tile * k_seq_tile_size, k_seq_tile_size)
+    k_local[i, :, :] = nl.load(k_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice],
+                                           dtype=k_local.dtype)
+    v_local[i, :, :] = nl.load(v_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice],
+                                           dtype=v_local.dtype)
+    ##############################################################
+    # Prefetch k transpose for the backward too
+    ##############################################################
+    for j in nl.affine_range(k_seq_fwd_bwd_tile_multipler):
+      i_k_dslice = nl.ds(j * k_seq_tile_size_backward, k_seq_tile_size_backward)
+      transposed_k_local[j, i, :, :] = nisa.nc_transpose(k_local[i, :, i_k_dslice])
+
+
+@nki.jit(mode='trace')
+def compute_rowsum(dy_o_sum, dy_ref_hbm_tile, o_ref_hbm_tile, d_head_n_tiles, d_head_tile_size, q_seq_n_tiles,
+                   q_seq_tile_size):
+  mixed_dtype = dy_o_sum.dtype
+  for i in nl.affine_range(q_seq_n_tiles):
+    dy_o_partial = nl.zeros((par_dim(q_seq_tile_size), d_head_n_tiles), dtype=mixed_dtype)
+    for j in nl.affine_range(d_head_n_tiles):
+      d_head_dslice = nl.ds(j * d_head_tile_size, d_head_tile_size)
+      q_seq_dslice = nl.ds(i * q_seq_tile_size, q_seq_tile_size)
+
+      dy_local = nl.load_transpose2d(dy_ref_hbm_tile[d_head_dslice, q_seq_dslice],
+                                     dtype=mixed_dtype)
+      o_local = nl.load_transpose2d(o_ref_hbm_tile[d_head_dslice, q_seq_dslice],
+                                    dtype=mixed_dtype)
+
+      dy_o = nl.multiply(dy_local, o_local, dtype=mixed_dtype)
+      dy_o_partial[:, j] = nisa.tensor_reduce(np.add, data=dy_o, axis=(1,),
+                                              dtype=mixed_dtype)
+
+    dy_o_sum[i, :, 0] = nisa.tensor_reduce(
+      np.add, data=dy_o_partial[:, :], axis=(1,), dtype=mixed_dtype)
+
+
+@nki.jit(mode='trace')
 def _flash_attn_bwd_core(
   q_local, k_local, transposed_k_local, v_local, dy_local,
   dk_psum, dv_psum, dq_local_reduced,
@@ -703,11 +805,7 @@ def _flash_attn_bwd_core(
   kernel_dtype, mixed_dtype,
   softmax_scale,
   seed_local, dropout_p, dropout_p_local,
-  global_i_q_seq_tile = None,
-  global_i_k_seq_tile = None,
-  # Used for nl.loop_reduce on dQ if local_i_k_seq_tile is not an index e.g. if it has an offset
-  local_i_k_seq_tile_for_dq_reduce = None,
-):
+  logit_bias_tile=None):
   """
   The flash backward core function to calculate the gradients of Q, K and V
   of the given tiles. The result will be accumulated into the dk, dv, dq psum
@@ -721,14 +819,7 @@ def _flash_attn_bwd_core(
   k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128
   k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward
 
-  if global_i_q_seq_tile is None:
-    global_i_q_seq_tile = local_i_q_seq_tile
-    global_i_k_seq_tile = local_i_k_seq_tile
-  
-  if local_i_k_seq_tile_for_dq_reduce is None:
-    local_i_k_seq_tile_for_dq_reduce = local_i_k_seq_tile
-
-  mask = global_i_q_seq_tile * q_seq_tile_size >= global_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None
+  mask = local_i_q_seq_tile * q_seq_tile_size >= local_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None
   # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F]
   qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
                       dtype=np.float32, buffer=nl.psum)
@@ -736,66 +827,75 @@ def _flash_attn_bwd_core(
 
   batch_id = nl.program_id(axis=0)
   head_id = nl.program_id(axis=1)
-  # Tensor indices for accessing qk result in k_seq_tile_size
-  if_q = nl.arange(q_seq_tile_size)[None, :]
-  ip_qk = nl.arange(d_head_tile_size)[:, None]
-
-  ip_q = nl.arange(q_seq_tile_size)[:, None]
-  if_k = nl.arange(k_seq_tile_size)[None, :]
 
   # Loop over contraction dim of QK matmul
   for i_d_head_tile in nl.affine_range(d_head_n_tiles):
     ##############################################################
     # Step 2.1 Compute Q^T@K, with matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
     ##############################################################
-    qk_psum[ip_q, if_k] += nisa.nc_matmul(q_local[i_d_head_tile, ip_qk, if_q],
-                                            k_local[i_d_head_tile, ip_qk, if_k],
+    qk_psum[:, :] += nisa.nc_matmul(q_local[i_d_head_tile, :, :],
+                                            k_local[i_d_head_tile, :, :],
                                             mask=mask)
 
   ######################################
   # Step 2.2. Apply optional causal mask
   ######################################
   if use_causal_mask:
-    # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
-    qk_res_buf[ip_q, if_k] = nisa.affine_select(
-      pred=(global_i_q_seq_tile * q_seq_tile_size + ip_q >= global_i_k_seq_tile * k_seq_tile_size + if_k),
-      on_true_tile=qk_psum[ip_q, if_k], on_false_value=-9984.0, dtype=mixed_dtype,
-      mask=mask)
+    iq, ik = nl.mgrid[0:q_seq_tile_size, 0:k_seq_tile_size]
+    causal_pred = (local_i_q_seq_tile * q_seq_tile_size + iq >= local_i_k_seq_tile * k_seq_tile_size + ik)
+    if logit_bias_tile is not None:
+      # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+      intermediate = \
+        nl.add(qk_psum[:, :], logit_bias_tile[:, :], dtype=mixed_dtype, mask=mask)
+      qk_res_buf[:, :] = nisa.affine_select(
+        pred=causal_pred, 
+        on_true_tile=intermediate, on_false_value=-9984.0, dtype=mixed_dtype,
+        mask=mask
+      )
+
+    else:
+      # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+      qk_res_buf[:, :] = nisa.affine_select(
+        pred=causal_pred,
+        on_true_tile=qk_psum[:, :], on_false_value=-9984.0, dtype=mixed_dtype,
+        mask=mask)
   else:
-    # Simply send psum result back to sbuf
-    qk_res_buf[ip_q, if_k] = \
-      nl.copy(qk_psum[ip_q, if_k], dtype=mixed_dtype)
+    if logit_bias_tile is not None:
+      # Simply add logit bias which copies back to sbuf at the same time
+      qk_res_buf[:, :] = \
+        nl.add(qk_psum[:, :], logit_bias_tile[:, :], dtype=mixed_dtype)
+    else:
+      # Simply send psum result back to sbuf
+      qk_res_buf[:, :] = \
+        nl.copy(qk_psum[:, :], dtype=mixed_dtype)
 
   softmax_y = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
-  softmax_y[ip_q, if_k] = nisa.activation(np.exp,
-                                            data=qk_res_buf[ip_q, if_k],
-                                            bias=softmax_exp_bias[local_i_q_seq_tile, ip_q, 0],
-                                            scale=1.0,
-                                            mask=mask)
+  softmax_y[:, :] = nisa.activation(np.exp,
+                                    data=qk_res_buf[:, :],
+                                    bias=softmax_exp_bias[:, local_i_q_seq_tile],
+                                    scale=1.0,
+                                    mask=mask)
   #####################################################################
   # Dropout
   #####################################################################
   if dropout_p > 0.0:
-    offset = global_i_k_seq_tile + global_i_q_seq_tile * k_seq_n_tiles \
+    offset = local_i_k_seq_tile + local_i_q_seq_tile * k_seq_n_tiles \
               + head_id * k_seq_n_tiles * q_seq_n_tiles \
               + batch_id * nheads * k_seq_n_tiles * q_seq_n_tiles
     offset_seed = nl.add(seed_local[0, 0], offset, mask=mask)
     nl.random_seed(seed=offset_seed, mask=mask)
-    softmax_y[ip_q, if_k] = nl.dropout(softmax_y[ip_q, if_k], rate=dropout_p_local[ip_q, 0], mask=mask)
-    softmax_y[ip_q, if_k] = nl.multiply(softmax_y[ip_q, if_k], 1 / (1 - dropout_p), mask=mask)
+    softmax_y[:, :] = nl.dropout(softmax_y[:, :], rate=dropout_p_local[:, 0], mask=mask)
+    softmax_y[:, :] = nl.multiply(softmax_y[:, :], 1 / (1 - dropout_p), mask=mask)
 
   #####################################################################
   # Step 3.1 Calculate the backward gradients dL/dV, where y=softmax@V
   # in value projection with matmul(stationary=dy, moving=softmax)
   #####################################################################
   for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-    ip_dv = nl.arange(d_head_tile_size)[:, None]
-    if_dv = nl.arange(k_seq_tile_size)[None, :]
-    if_trans_dy = nl.arange(q_seq_tile_size)[None, :]
-    trans_dy = nisa.nc_transpose(dy_local[i_d_head_tile, ip_dv, if_trans_dy],
+    trans_dy = nisa.nc_transpose(dy_local[i_d_head_tile, :, :],
                                   mask=mask)
-    dv_psum[i_d_head_tile, ip_dv, if_dv] += \
-      nisa.nc_matmul(trans_dy, softmax_y[ip_q, if_k], mask=mask)
+    dv_psum[i_d_head_tile, :, :] += \
+      nisa.nc_matmul(trans_dy, softmax_y[:, :], mask=mask)
 
   #####################################################################
   # Step 3.2 Calculate the backward gradients dL/dsoftmax, where y=softmax@V
@@ -804,15 +904,13 @@ def _flash_attn_bwd_core(
   softmax_dy_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
                               dtype=np.float32, buffer=nl.psum)
   for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-    ip_softmax_dy = nl.arange(d_head_tile_size)[:, None]
-    if_dy = nl.arange(q_seq_tile_size)[None, :]
-    softmax_dy_psum[ip_q, if_k] += \
-      nisa.nc_matmul(dy_local[i_d_head_tile, ip_softmax_dy, if_dy],
-                      v_local[i_d_head_tile, ip_softmax_dy, if_k],
+    softmax_dy_psum[:, :] += \
+      nisa.nc_matmul(dy_local[i_d_head_tile, :, :],
+                      v_local[i_d_head_tile, :, :],
                       mask=mask)
 
   softmax_dy = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
-  softmax_dy[ip_q, if_k] = nl.copy(softmax_dy_psum[ip_q, if_k], dtype=kernel_dtype,
+  softmax_dy[:, :] = nl.copy(softmax_dy_psum[:, :], dtype=kernel_dtype,
                                       mask=mask)
 
   #####################################################################
@@ -820,61 +918,55 @@ def _flash_attn_bwd_core(
   # dL/dx = y * (dL/dy - rowsum(dO_O)), where y = softmax(x)
   #####################################################################
   softmax_dx_local = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
-  softmax_dx_local[ip_q, if_k] = \
-    nisa.scalar_tensor_tensor(data=softmax_dy[ip_q, if_k],
+  softmax_dx_local[:, :] = \
+    nisa.scalar_tensor_tensor(data=softmax_dy[:, :],
                               op0=np.subtract,
-                              operand0=dy_o_sum[local_i_q_seq_tile, ip_q, 0],
+                              operand0=dy_o_sum[local_i_q_seq_tile, :, 0],
                               op1=np.multiply,
-                              operand1=softmax_y[ip_q, if_k],
+                              operand1=softmax_y[:, :],
                               mask=mask)
 
   #####################################################################
   # Step 5.1 Calculate dK, with matmul(stationary=Q, moving=softmax_dx)
   #####################################################################
   for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-    ip_trans_q = nl.arange(d_head_tile_size)[:, None]
-    if_trans_q = nl.arange(q_seq_tile_size)[None, :]
-    ip_dk = nl.arange(d_head_tile_size)[:, None]
-    trans_q_local = nisa.nc_transpose(q_local[i_d_head_tile, ip_trans_q, if_trans_q],
+    trans_q_local = nisa.nc_transpose(q_local[i_d_head_tile, :, :],
                                       mask=mask)
-    dk_psum[i_d_head_tile, ip_dk, if_k] += \
+    dk_psum[i_d_head_tile, :, :] += \
       nisa.nc_matmul(trans_q_local,
-                      softmax_dx_local[ip_q, if_k],
+                      softmax_dx_local[:, :],
                       mask=mask)
 
   #####################################################################
   # Step 5.2 Calculate dQ
   #####################################################################
-  if_k = nl.arange(k_seq_tile_size_backward)[None, :]
-  ip_dq = nl.arange(d_head_tile_size)[:, None]
-  if_dq = nl.arange(q_seq_tile_size)[None, :]
-  if_d = nl.arange(d_head_tile_size)[None, :]
-  ip_transposed_k = nl.arange(k_seq_tile_size_backward)[:, None]
   for i_d_head_tile in nl.affine_range(d_head_n_tiles):
     dq_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size),
                         dtype=np.float32, buffer=nl.psum)
     for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler):
+      i_k_seq_dslice = nl.ds(i_k_seq_tile_backward * k_seq_tile_size_backward,
+                             k_seq_tile_size_backward)
       transposed_softmax_dx_local = \
-        nisa.nc_transpose(softmax_dx_local[ip_q, i_k_seq_tile_backward * k_seq_tile_size_backward + if_k],
+        nisa.nc_transpose(softmax_dx_local[:, i_k_seq_dslice],
                           mask=mask)
-      dq_psum[ip_dq, if_dq] += nisa.nc_matmul(
-          transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, ip_transposed_k, if_d],
+      dq_psum[:, :] += nisa.nc_matmul(
+          transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, :, :],
           transposed_softmax_dx_local,
           mask=mask)
-    dq_local = nl.multiply(dq_psum[ip_dq, if_dq], softmax_scale, dtype=kernel_dtype, mask=mask)
-    dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, ip_dq, if_dq] = nl.loop_reduce(
-      dq_local, op=np.add, loop_indices=(local_i_k_seq_tile_for_dq_reduce,),
+    dq_local = nl.multiply(dq_psum[:, :], softmax_scale, dtype=kernel_dtype, mask=mask)
+    dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, :, :] = nl.loop_reduce(
+      dq_local, op=np.add, loop_indices=(local_i_k_seq_tile,),
       dtype=mixed_dtype, mask=mask)
 
 
 @nki.jit
 def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
-                                           mixed_percision=True):
+                                           mixed_precision=True):
   """
   Fused self attention kernel for small head size Stable Diffusion workload.
 
   Computes softmax(QK^T)V. Decoder model can optionally include a causal mask
-  application. Does not include QKV rojection, output projection, dropout,
+  application. Does not include QKV projection, output projection, dropout,
   residual connection, etc.
 
   This kernel is designed to be used for Stable Diffusion models where the
@@ -890,14 +982,14 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=
 
   IO tensor dtypes:
    - This kernel assumes all IO tensors have the same dtype
-   - If mixed_percision is True, then all Tensor Engine operation will be performed in
+   - If mixed_precision is True, then all Tensor Engine operation will be performed in
      bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
      will be in the same type as the inputs.
   """
   # Use q_ref dtype as the intermediate tensor dtype
   # Assume all IO tensors have the same dtype
   kernel_dtype = q_ref.dtype
-  pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
+  pe_in_dt = nl.bfloat16 if mixed_precision else np.float32
   assert q_ref.dtype == k_ref.dtype == v_ref.dtype
 
   # Shape checking
diff --git a/src/reference/tutorial.py b/src/nki_samples/reference/tutorial.py
similarity index 100%
rename from src/reference/tutorial.py
rename to src/nki_samples/reference/tutorial.py
diff --git a/src/reference/vision.py b/src/nki_samples/reference/vision.py
similarity index 100%
rename from src/reference/vision.py
rename to src/nki_samples/reference/vision.py
diff --git a/src/tutorials/average_pool2d/average_pool2d_jax.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py
similarity index 100%
rename from src/tutorials/average_pool2d/average_pool2d_jax.py
rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py
diff --git a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py
similarity index 100%
rename from src/tutorials/average_pool2d/average_pool2d_nki_kernels.py
rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py
diff --git a/src/tutorials/average_pool2d/average_pool2d_torch.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py
similarity index 100%
rename from src/tutorials/average_pool2d/average_pool2d_torch.py
rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py
diff --git a/src/tutorials/fused_mamba/mamba_nki_kernels.py b/src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py
similarity index 100%
rename from src/tutorials/fused_mamba/mamba_nki_kernels.py
rename to src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py
diff --git a/src/tutorials/fused_mamba/mamba_torch.py b/src/nki_samples/tutorials/fused_mamba/mamba_torch.py
similarity index 100%
rename from src/tutorials/fused_mamba/mamba_torch.py
rename to src/nki_samples/tutorials/fused_mamba/mamba_torch.py
diff --git a/src/tutorials/layernorm/layernorm_nki_kernel.py b/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py
similarity index 100%
rename from src/tutorials/layernorm/layernorm_nki_kernel.py
rename to src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py
diff --git a/src/tutorials/layernorm/layernorm_torch.py b/src/nki_samples/tutorials/layernorm/layernorm_torch.py
similarity index 100%
rename from src/tutorials/layernorm/layernorm_torch.py
rename to src/nki_samples/tutorials/layernorm/layernorm_torch.py
diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
similarity index 100%
rename from src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
rename to src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py
similarity index 100%
rename from src/tutorials/matrix_multiplication/matrix_multiplication_torch.py
rename to src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py
diff --git a/src/tutorials/rmsnorm/rmsnorm_jax.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py
similarity index 100%
rename from src/tutorials/rmsnorm/rmsnorm_jax.py
rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py
diff --git a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py
similarity index 100%
rename from src/tutorials/rmsnorm/rmsnorm_nki_kernels.py
rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py
diff --git a/src/tutorials/rmsnorm/rmsnorm_torch.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py
similarity index 100%
rename from src/tutorials/rmsnorm/rmsnorm_torch.py
rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py
diff --git a/src/tutorials/sd_attention/sd_attention_nki_kernels.py b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py
similarity index 100%
rename from src/tutorials/sd_attention/sd_attention_nki_kernels.py
rename to src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py
diff --git a/src/tutorials/sd_attention/sd_attention_torch.py b/src/nki_samples/tutorials/sd_attention/sd_attention_torch.py
similarity index 100%
rename from src/tutorials/sd_attention/sd_attention_torch.py
rename to src/nki_samples/tutorials/sd_attention/sd_attention_torch.py
diff --git a/src/tutorials/tensor_addition/tensor_addition_jax.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py
similarity index 100%
rename from src/tutorials/tensor_addition/tensor_addition_jax.py
rename to src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py
diff --git a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py
similarity index 100%
rename from src/tutorials/tensor_addition/tensor_addition_nki_kernels.py
rename to src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py
diff --git a/src/tutorials/tensor_addition/tensor_addition_torch.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py
similarity index 100%
rename from src/tutorials/tensor_addition/tensor_addition_torch.py
rename to src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py
diff --git a/src/tutorials/transpose2d/transpose2d_jax.py b/src/nki_samples/tutorials/transpose2d/transpose2d_jax.py
similarity index 100%
rename from src/tutorials/transpose2d/transpose2d_jax.py
rename to src/nki_samples/tutorials/transpose2d/transpose2d_jax.py
diff --git a/src/tutorials/transpose2d/transpose2d_nki_kernels.py b/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py
similarity index 100%
rename from src/tutorials/transpose2d/transpose2d_nki_kernels.py
rename to src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py
diff --git a/src/tutorials/transpose2d/transpose2d_torch.py b/src/nki_samples/tutorials/transpose2d/transpose2d_torch.py
similarity index 100%
rename from src/tutorials/transpose2d/transpose2d_torch.py
rename to src/nki_samples/tutorials/transpose2d/transpose2d_torch.py
diff --git a/src/reference/__init__.py b/src/reference/__init__.py
deleted file mode 100644
index 922dd83..0000000
--- a/src/reference/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright (c) 2023, Amazon.com. All Rights Reserved
-
-"""
-Package containing public kernels for Neuron Kernel Interface (NKI).
-
-Kernels here are the same to the ones available in the 
-NKI Github Sample Repo.
-
-https://github.com/aws-neuron/nki-samples
-"""
-from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size, flash_attn_bwd, flash_fwd
-from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel, select_and_scatter_kernel
-from neuronxcc.nki.kernels.tutorial import add_kernel_nx8x128x512
-from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size
-from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv
-
-from neuronxcc.nki._private_kernels.legacy.attention import \
-  (fused_self_attn_for_SD_small_head_size as _fused_self_attn_for_SD_small_head_size,
-   flash_attn_bwd as _flash_attn_bwd, flash_fwd as _flash_fwd)
-from neuronxcc.nki._private_kernels.legacy.vision import (
-  resize_nearest_fixed_dma_kernel as _resize_nearest_fixed_dma_kernel,
-  select_and_scatter_kernel as _select_and_scatter_kernel)
-from neuronxcc.nki._private_kernels.legacy.tutorial import add_kernel_nx8x128x512 as _add_kernel_nx8x128x512
-from neuronxcc.nki._private_kernels.legacy.allocated_fused_linear import _allocated_fused_rms_norm_qkv
-
-fused_self_attn_for_SD_small_head_size._legacy_func = _fused_self_attn_for_SD_small_head_size
-flash_attn_bwd._legacy_func = _flash_attn_bwd
-flash_fwd._legacy_func = _flash_fwd
-resize_nearest_fixed_dma_kernel._legacy_func = _resize_nearest_fixed_dma_kernel
-select_and_scatter_kernel._legacy_func = _select_and_scatter_kernel
-add_kernel_nx8x128x512._legacy_func = _add_kernel_nx8x128x512
-allocated_fused_rms_norm_qkv._legacy_func = _allocated_fused_rms_norm_qkv
diff --git a/test/unit/README.md b/test/unit/README.md
index e0835de..55dc937 100644
--- a/test/unit/README.md
+++ b/test/unit/README.md
@@ -1 +1,7 @@
-Tests under this folder are unit tests for the kernels in `neuronxcc.nki.kernels`, and they are part of the nki-samples Github Repo. Only public APIs can be used for tests in this folder.
\ No newline at end of file
+Tests under this folder are unit tests for the kernels in `src/nki_samples`. 
+
+To execute the tests, we need to include `src/nki_samples` in the `PYTHONPATH`.
+
+For example, 
+
+PYTHONPATH=$PYTHONPATH:/home/ubuntu/nki-samples/src/ pytest test_flash_attn_fwd.py
\ No newline at end of file
diff --git a/test/unit/__main__.py b/test/unit/__main__.py
deleted file mode 100644
index 34fee3a..0000000
--- a/test/unit/__main__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import os
-import sys
-
-# This file is basically a hack around the fact that pytest has a bug where it does not discover conftest.py correctly if you launch the test using --pyargs.
-# https://github.com/pytest-dev/pytest/issues/1596
-
-
-# Todo: Using __file__ isn't the most robust. Figure out how to do this using importlib or similar.
-test_root = os.path.dirname(__file__)
-
-if __name__ == "__main__":
-  import pytest
-  errcode = pytest.main([test_root] + sys.argv[1:])
-  sys.exit(errcode)
\ No newline at end of file
diff --git a/test/unit/conftest.py b/test/unit/conftest.py
new file mode 100644
index 0000000..cd663ae
--- /dev/null
+++ b/test/unit/conftest.py
@@ -0,0 +1,28 @@
+import pytest
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--simulation-only", action="store_true", default=False, help="Run simulation only, it will run test with `simulation` marker in simulation mode"
+    )
+
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers", "simulation: mark simulation test that can be executed without a NeuronDevice"
+    )
+
+@pytest.fixture
+def simulation_only(request):
+    return request.config.getoption("--simulation-only")
+
+def pytest_collection_modifyitems(session, config, items):
+    if config.getoption("--simulation-only"):
+        # Only run cases with `simulation marker`
+        result = []
+        for item in items:
+            for marker in item.iter_markers():
+                if marker.name == 'simulation':
+                    result.append(item)
+                    break
+        items.clear()
+        items.extend(result)
+        
\ No newline at end of file
diff --git a/test/unit/test_SD_attention_small_head.py b/test/unit/test_SD_attention_small_head.py
index 32e6945..1a54a4b 100644
--- a/test/unit/test_SD_attention_small_head.py
+++ b/test/unit/test_SD_attention_small_head.py
@@ -3,14 +3,13 @@
 """
 import os
 import pytest
-from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.attention import fused_self_attn_for_SD_small_head_size
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
 from scipy.special import softmax
 
 test_trace_file_path='local_trace.ntff'
-numeric_func = baremetal(fused_self_attn_for_SD_small_head_size)
 bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size)
 
 def cpu_golden_attn(q, k, v):
@@ -46,11 +45,12 @@ def test_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency):
         assert p50 <= latency*1.05 # short running kernels are subjected to hardware fluctuation
         assert os.path.getsize(test_trace_file_path) > 0
 
+    @pytest.mark.simulation
     @pytest.mark.parametrize("bs,seqlen,d,dtype", [
         [1, 4096, 128, np.float32],
         [1, 4096, 128, nl.bfloat16]
     ])
-    def test_attention_for_SD_numberic(self, bs, seqlen, d, dtype):
+    def test_attention_for_SD_numberic(self, simulation_only, bs, seqlen, d, dtype):
         q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
         k = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
         v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
@@ -59,7 +59,11 @@ def test_attention_for_SD_numberic(self, bs, seqlen, d, dtype):
         k_dev = nl.static_cast(k, dtype)
         v_dev = nl.static_cast(v, dtype)
 
-        out = numeric_func[bs](q_dev, k_dev, v_dev)
+        numeric_func = baremetal(fused_self_attn_for_SD_small_head_size)
+        if simulation_only:
+            out = simulate_kernel(numeric_func[bs], q_dev, k_dev, v_dev)
+        else:
+            out = numeric_func[bs](q_dev, k_dev, v_dev)
         out = nl.static_cast(out, np.float32)
         golden_result = cpu_golden_attn(q, k, v)
         assert np.allclose(out, golden_result, atol=1e-2)
diff --git a/test/unit/test_allocated_SD_attention_small_head.py b/test/unit/test_allocated_SD_attention_small_head.py
index ee0de86..712148f 100644
--- a/test/unit/test_allocated_SD_attention_small_head.py
+++ b/test/unit/test_allocated_SD_attention_small_head.py
@@ -3,15 +3,15 @@
 """
 import os
 import pytest
-from neuronxcc.nki.kernels.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 import numpy as np
 from scipy.special import softmax
 
 test_trace_file_path='local_trace.ntff'
-numeric_func = baremetal(allocated_fused_self_attn_for_SD_small_head_size)
+
 bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(allocated_fused_self_attn_for_SD_small_head_size)
 
 def cpu_golden_attn(q, k, v):
@@ -47,12 +47,13 @@ def test_allocated_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency):
         assert p50 <= latency * 1.05 # short running kernels are subjected to hardware fluctuation
         assert os.path.getsize(test_trace_file_path) > 0
 
+    @pytest.mark.simulation
     @pytest.mark.parametrize("bs,seqlen,d,dtype", [
         [1, 4096, 128, np.float32],
         [1, 4096, 128, nl.bfloat16],
         [1, 5120, 128, nl.bfloat16]
     ])
-    def test_allocated_attention_for_SD_numberic(self, bs, seqlen, d, dtype):
+    def test_allocated_attention_for_SD_numberic(self, simulation_only, bs, seqlen, d, dtype):
         q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
         k = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
         v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
@@ -61,7 +62,11 @@ def test_allocated_attention_for_SD_numberic(self, bs, seqlen, d, dtype):
         k_dev = nl.static_cast(k, dtype)
         v_dev = nl.static_cast(v, dtype)
 
-        out = numeric_func[bs](q_dev, k_dev, v_dev)
+        numeric_func = baremetal(allocated_fused_self_attn_for_SD_small_head_size)
+        if simulation_only:
+            out = simulate_kernel(numeric_func[bs], q_dev, k_dev, v_dev)
+        else:
+            out = numeric_func[bs](q_dev, k_dev, v_dev)
         out = nl.static_cast(out, np.float32)
         golden_result = cpu_golden_attn(q, k, v)
         assert np.allclose(out, golden_result, atol=1e-2)
diff --git a/test/unit/test_flash_attn_bwd.py b/test/unit/test_flash_attn_bwd.py
index 3aedab0..0f45f9f 100644
--- a/test/unit/test_flash_attn_bwd.py
+++ b/test/unit/test_flash_attn_bwd.py
@@ -2,14 +2,14 @@
 Copyright (c) 2023, Amazon.com. All Rights Reserved
 """
 import pytest
-from neuronxcc.nki.kernels.attention import flash_attn_bwd
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.attention import flash_attn_bwd
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
 
-from TestDecorators import xfail
+xfail = pytest.mark.arch_specific_xfail
+
 
-numeric_func = baremetal(flash_attn_bwd)
 bench_func = benchmark(warmup=5, iters=10)(flash_attn_bwd)
 
 def softmax(x: np.ndarray, dim: int, zero_max_mode=False,
@@ -110,13 +110,14 @@ def test_flash_attn_bwd_perf(self, bs, nheads, seqlen, d, dtype, latency):
         bench_func_(q, k, v, o_proj, dy, lse, seed,
                     use_causal_mask=True, mixed_precision=True)
         latency_res = bench_func_.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
+        p99 = latency_res.get_latency_percentile(50)
         assert p99 <= latency
 
+    @pytest.mark.simulation
     @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype", [
         [1, 4, 4096, 128, np.float32],
     ])
-    def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype):
+    def test_flash_attn_bwd_numerical(self, simulation_only, bs, nheads, seqlen, d, dtype):
         q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
         k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
         v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
@@ -135,7 +136,13 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype):
                                                               nl.tile_size.pmax).transpose(0, 1, 3, 2)
         lse = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal))
 
-        out_dq, out_dk, out_dv = numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
+        numeric_func = baremetal(flash_attn_bwd)
+        if simulation_only:
+           out_dq, out_dk, out_dv = simulate_kernel(numeric_func[bs, nheads], q, k, v, o_proj, dy, lse, seed,
+                                                          use_causal_mask=True,
+                                                          mixed_precision=True)
+        else:
+          out_dq, out_dk, out_dv = numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
                                                           use_causal_mask=True,
                                                           mixed_precision=True)
 
diff --git a/test/unit/test_flash_attn_fwd.py b/test/unit/test_flash_attn_fwd.py
index fff4ac2..e52354d 100644
--- a/test/unit/test_flash_attn_fwd.py
+++ b/test/unit/test_flash_attn_fwd.py
@@ -2,12 +2,11 @@
 Copyright (c) 2023, Amazon.com. All Rights Reserved
 """
 import pytest
-from neuronxcc.nki.kernels.attention import flash_fwd, FlashConfig
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.attention import flash_fwd, FlashConfig
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
- 
-numeric_func = baremetal(flash_fwd)
+
 bench_func = benchmark(warmup=5, iters=10)(flash_fwd)
  
 def softmax(x: np.ndarray, dim: int, zero_max_mode=False,
@@ -95,9 +94,10 @@ def test_flash_attn_fwd_perf(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use
         bench_func_(q, k, v, seed, use_causal_mask=use_causal_mask,
                     mixed_precision=mixed_precision, config=config)
         latency_res = bench_func_.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
+        p99 = latency_res.get_latency_percentile(50)
         assert p99 <= latency
- 
+    
+    @pytest.mark.simulation
     @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\
                               training, tile_size, kv_heads, should_transpose_v", [
     [1, 6, 4096, 4096, 128, np.float32, True, True, 2048, 3, False],
@@ -105,7 +105,7 @@ def test_flash_attn_fwd_perf(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use
     [1, 1, 8192, 4096, 128, np.float32, True, False, 2048, None, False],
     [1, 1, 4096, 8192, 128, np.float32, True, False, 2048, None, False],
     ])
-    def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, 
+    def test_flash_attn_fwd_numerical(self, simulation_only, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, 
                                      training, tile_size, kv_heads, should_transpose_v):
         q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2
         k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen_k]) - 0.5) * 2
@@ -132,7 +132,15 @@ def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen_q, seqlen_k, d, dtype
         config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v})
 
         heads = nheads if kv_heads is None else kv_heads
-        results = numeric_func[bs, heads](q, k, v, seed,
+
+        numeric_func = baremetal(flash_fwd)
+        if simulation_only:
+            results = simulate_kernel(numeric_func[bs, heads], q, k, v, seed,
+                                          use_causal_mask=use_causal_mask,
+                                          mixed_precision=True,
+                                          config=config)
+        else:
+            results = numeric_func[bs, heads](q, k, v, seed,
                                           use_causal_mask=use_causal_mask,
                                           mixed_precision=True,
                                           config=config)
diff --git a/test/unit/test_neuron_profile.py b/test/unit/test_neuron_profile.py
new file mode 100644
index 0000000..e607705
--- /dev/null
+++ b/test/unit/test_neuron_profile.py
@@ -0,0 +1,86 @@
+from neuronxcc.nki import benchmark
+from neuronxcc.nki import profile
+import neuronxcc.nki.language as nl
+import numpy as np
+import pytest
+import os
+import shutil
+import tempfile
+
+
+WORKING_DIRECTORY = tempfile.mkdtemp()
+SAVE_NEFF_NAME = "cus_file123.neff"
+SAVE_TRACE_NAME = "profile-custom.ntff"
+NUM_EXECS = 20
+PROFILE_NTH = 10  
+JSON_REPORTS = "json_reports"
+
+@profile(working_directory=WORKING_DIRECTORY, save_neff_name=SAVE_NEFF_NAME, overwrite=False , save_trace_name=SAVE_TRACE_NAME, num_execs=NUM_EXECS, profile_nth=PROFILE_NTH)
+def nki_tensor_tensor_add(a_tensor, b_tensor):
+  c_output = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, buffer=nl.shared_hbm)
+ 
+  a = nl.load(a_tensor)
+  b = nl.load(b_tensor)
+
+  c_tile = a + b
+
+  nl.store(c_output, value=c_tile)
+
+  return c_output
+
+class TestNeuronProfile:
+    def _get_ntff_path(self, trace_val):
+        """
+        Prepares ntff file name based on execution trace number
+        """
+        if trace_val == 1:
+            return os.path.join(WORKING_DIRECTORY, f"{os.path.splitext(os.path.basename(SAVE_TRACE_NAME))[0]}.ntff")
+        else:
+            return os.path.join(WORKING_DIRECTORY, f"{os.path.splitext(os.path.basename(SAVE_TRACE_NAME))[0]}_exec_{trace_val}.ntff")
+
+    @pytest.fixture
+    def traces(self):
+        ret = []
+        if NUM_EXECS < PROFILE_NTH:
+            ret.append(self._get_ntff_path(PROFILE_NTH))
+        else:
+            curr = PROFILE_NTH
+            while curr <= NUM_EXECS:
+                ret.append(self._get_ntff_path(curr))
+                curr += PROFILE_NTH
+        return ret
+    
+    @pytest.fixture
+    def num_reports(self):
+        if NUM_EXECS < PROFILE_NTH:
+            return 1
+        else:
+            return NUM_EXECS // PROFILE_NTH
+
+    def test_output_artifacts_created(self, traces, num_reports):
+        # delete artifact directory, only testing non-overwrite functionality
+        if os.path.exists(WORKING_DIRECTORY):
+            shutil.rmtree(WORKING_DIRECTORY)
+
+        # creates dummy input to invoke profile kernel
+        a = np.zeros([128, 1024]).astype(np.float16)
+        b = np.random.random_sample([128, 1024]).astype(np.float16)
+
+        output_nki = nki_tensor_tensor_add(a, b)
+
+        # now asserting artifacts are correctly created     
+        assert os.path.exists(os.path.join(WORKING_DIRECTORY, SAVE_NEFF_NAME)) # neff
+        
+        for trace in traces:
+            assert os.path.exists(trace) # trace
+        
+        # json reports
+        report_dir = os.path.join(WORKING_DIRECTORY, JSON_REPORTS)
+
+        assert os.path.exists(report_dir) # actually exists
+        assert len(os.listdir(report_dir)) == num_reports # report all iterations queried
+
+        # post condition cleanup
+        if os.path.exists(WORKING_DIRECTORY):
+            shutil.rmtree(WORKING_DIRECTORY)
+
diff --git a/test/unit/test_resize_nearest.py b/test/unit/test_resize_nearest.py
index 2bbc601..72e7aef 100644
--- a/test/unit/test_resize_nearest.py
+++ b/test/unit/test_resize_nearest.py
@@ -3,12 +3,11 @@
 """
 import pytest
 
-from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.vision import resize_nearest_fixed_dma_kernel
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
 
-numeric_func = baremetal(resize_nearest_fixed_dma_kernel)
 bench_func = benchmark(warmup=5, iters=10)(resize_nearest_fixed_dma_kernel)
 
 
@@ -49,20 +48,25 @@ def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out
         bench_func_ = bench_func[in_b]
         bench_func_(input_dev, (out_b, out_h, out_w, out_c))
         latency_res = bench_func_.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
+        p99 = latency_res.get_latency_percentile(50)
         assert p99 <= latency
 
+    @pytest.mark.simulation
     @pytest.mark.parametrize("in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype", [
  	    [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32],
         [1, 30, 20, 1280, 1, 59, 38, 1280, nl.float16],
         [1, 30, 20, 1280, 1, 59, 38, 1280, nl.bfloat16],
  	])
-    def test_resize_nearest_for_numberic(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype):
+    def test_resize_nearest_for_numberic(self, simulation_only, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype):
         input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32)
 
         input_dev = nl.static_cast(input_tensor, dtype)
 
-        output_tensor = numeric_func[in_b](input_dev, (out_b, out_h, out_w, out_c))
+        numeric_func = baremetal(resize_nearest_fixed_dma_kernel)
+        if simulation_only:
+            output_tensor = simulate_kernel(numeric_func[in_b], input_dev, (out_b, out_h, out_w, out_c))
+        else:
+            output_tensor = numeric_func[in_b](input_dev, (out_b, out_h, out_w, out_c))
         output_tensor = nl.static_cast(output_tensor, np.float32)
         golden_result = cpu_golden_result(input_tensor, output_tensor.shape)
         assert np.allclose(output_tensor, golden_result, atol=1e-2)
diff --git a/test/unit/test_rmsnorm_qkv.py b/test/unit/test_rmsnorm_qkv.py
index 24ad31c..28838d1 100644
--- a/test/unit/test_rmsnorm_qkv.py
+++ b/test/unit/test_rmsnorm_qkv.py
@@ -2,12 +2,11 @@
 Copyright (c) 2024, Amazon.com. All Rights Reserved
 """
 import pytest
-from neuronxcc.nki.kernels.allocated_fused_linear import allocated_fused_rms_norm_qkv
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.allocated_fused_linear import allocated_fused_rms_norm_qkv
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
 
-numeric_func = baremetal(allocated_fused_rms_norm_qkv)
 bench_func = benchmark(warmup=5, iters=10)(allocated_fused_rms_norm_qkv)
 
 np.random.seed(0)
@@ -31,7 +30,7 @@ class TestRMSNormQKV:
     [1, 128, 512, 512, np.float16, 25],
     [1, 512, 1024, 384, nl.bfloat16, 40],
     [1, 128, 1024, 512, nl.bfloat16, 28],
-    [1, 1024, 8192, 512, nl.bfloat16, 301 * 1.02]
+    # [1, 1024, 8192, 512, nl.bfloat16, 301 * 1.02], # FIXME: performance is flaky
   ])
   def test_allocated_rmsnorm_qkv_perf(self, batch, seqlen, dim, d_head, dtype, latency):
     hidden = np.random.random_sample((batch, seqlen, dim)).astype(np.float32)
@@ -42,23 +41,28 @@ def test_allocated_rmsnorm_qkv_perf(self, batch, seqlen, dim, d_head, dtype, lat
 
     bench_func(hidden, weights)
     latency_res = bench_func.benchmark_result.nc_latency
-    p99 = latency_res.get_latency_percentile(99)
+    p99 = latency_res.get_latency_percentile(50)
     assert p99 <= latency
 
+  @pytest.mark.simulation
   @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype", [
     [1, 128, 512, 512, np.float16],
     [1, 512, 1024, 384, nl.bfloat16],
     [1, 128, 1024, 512, nl.bfloat16],
     [1, 1024, 8192, 512, nl.bfloat16]
   ])
-  def test_allocated_rmsnorm_qkv_numeric(self, batch, seqlen, dim, d_head, dtype):
+  def test_allocated_rmsnorm_qkv_numeric(self, simulation_only, batch, seqlen, dim, d_head, dtype):
     hidden = np.random.random_sample((batch, seqlen, dim))
     weights = np.random.random_sample((dim, d_head))
 
     hidden_dev = nl.static_cast(hidden, dtype)
     weights_dev = nl.static_cast(weights, dtype)
 
-    out = numeric_func(hidden_dev, weights_dev)
+    numeric_func = baremetal(allocated_fused_rms_norm_qkv)
+    if simulation_only:
+      out = simulate_kernel(numeric_func, hidden_dev, weights_dev)
+    else:
+      out = numeric_func(hidden_dev, weights_dev)
     out = nl.static_cast(out, np.float32)
     golden_res = nl.static_cast(cpu_golden_result(hidden, None, weights, dtype, do_norm=True), np.float32)
     assert np.allclose(out, golden_res, atol=1e-2, rtol=1e-2)
diff --git a/test/unit/test_select_and_scatter.py b/test/unit/test_select_and_scatter.py
index fc99b37..08e787f 100644
--- a/test/unit/test_select_and_scatter.py
+++ b/test/unit/test_select_and_scatter.py
@@ -1,11 +1,10 @@
 import pytest
 
-from neuronxcc.nki.kernels.vision import select_and_scatter_kernel
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.vision import select_and_scatter_kernel
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
 
-numeric_func = baremetal(select_and_scatter_kernel)
 bench_func = benchmark(warmup=5, iters=10)(select_and_scatter_kernel)
 
 np.random.seed(0)
@@ -51,14 +50,15 @@ def test_select_and_scatter_for_perf(self, n, c, operand_h, operand_w, source_h,
 
         bench_func(operand_dev, source_dev)
         latency_res = bench_func.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
+        p99 = latency_res.get_latency_percentile(50)
         assert p99 <= latency
 
+    @pytest.mark.simulation
     @pytest.mark.parametrize("n, c, operand_h, operand_w, source_h, source_w, dtype", [
  	    [8, 64, 112, 112, 56, 56, np.float32],
  	    [8, 64, 112, 112, 56, 56, nl.bfloat16],
  	])
-    def test_select_and_scatter_for_numeric(self, n, c, operand_h, operand_w, source_h, source_w, dtype):
+    def test_select_and_scatter_for_numeric(self,simulation_only, n, c, operand_h, operand_w, source_h, source_w, dtype):
         operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype)
         source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype)
 
@@ -66,7 +66,11 @@ def test_select_and_scatter_for_numeric(self, n, c, operand_h, operand_w, source
         operand_tensor = nl.static_cast(operand_dev, np.float32)
         source_tensor = nl.static_cast(source_dev, np.float32)
 
-        output_dev = numeric_func(operand_dev, source_dev)
+        numeric_func = baremetal(select_and_scatter_kernel)
+        if simulation_only:
+            output_dev = simulate_kernel(numeric_func, operand_dev, source_dev)
+        else:
+            output_dev = numeric_func(operand_dev, source_dev)
         golden_result = cpu_golden_result(operand_tensor, source_tensor)
         nki_result = nl.static_cast(output_dev, np.float32)