diff --git a/.github/workflows/run_simulation_tests.yml b/.github/workflows/run_simulation_tests.yml
new file mode 100644
index 0000000..7ea89ef
--- /dev/null
+++ b/.github/workflows/run_simulation_tests.yml
@@ -0,0 +1,32 @@
+name: Run Python Simulation Tests
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+permissions:
+  contents: read
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v3
+      with:
+        python-version: "3.10"
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
+        python -m pip install wget awscli
+        python -m pip install pytest
+        python -m pip install neuronx-cc==2.*
+    - name: Test with pytest
+      run: |
+        PYTHONPATH=$PYTHONPATH:src/ pytest test/unit/ --simulation-only
\ No newline at end of file
diff --git a/src/nki_samples/reference/__init__.py b/src/nki_samples/reference/__init__.py
new file mode 100644
index 0000000..c9e5d37
--- /dev/null
+++ b/src/nki_samples/reference/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2023, Amazon.com. All Rights Reserved
+
+"""
+Package containing public kernels for Neuron Kernel Interface (NKI).
+
+Kernels here are also available in the `neuronxcc.nki.kernels` namespace, and they 
+are synced with this repository on every Neuron SDK release. 
+
+https://github.com/aws-neuron/nki-samples
+"""
diff --git a/src/nki_samples/reference/allocated_attention.py b/src/nki_samples/reference/allocated_attention.py
new file mode 100644
index 0000000..94b513f
--- /dev/null
+++ b/src/nki_samples/reference/allocated_attention.py
@@ -0,0 +1,283 @@
+import functools
+import neuronxcc.nki as nki
+import neuronxcc.nki.language as nl
+import neuronxcc.nki.isa as nisa
+import neuronxcc.nki.compiler as ncc
+from neuronxcc.nki.language import par_dim
+import numpy as np
+
+@nki.jit
+def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref,
+                                           use_causal_mask=False,
+                                           mixed_precision=True):
+  """
+  Allocated fused self attention kernel for small head size Stable Diffusion workload.
+  
+  Computes (softmax(Q.T@K)V).T. The layout is wired to avoid transposes as
+  much as possible, which simplifies debugging. The kernel uses the direct allocation API
+  and implements double buffering to achieve better performance than automatic allocation.
+  As of Neuron SDK 2.21, it achieves 18% better performance than the auto-allocated equivalent.
+  To see the performance gap, you can use ``force_auto_alloc`` decorator to override
+  manual allocation and benchmark the performance difference.
+
+  This kernel is designed for Stable Diffusion models where the head size
+  ``d_head`` is equal to 128. Seqlen must be divisible by 1024 and smaller than 5120.
+  An assertion error is raised if ``d_head`` or the sequence length does not satisfy these requirements.
+  These restrictions simplify the address calculation in allocations.
+
+  IO tensor layouts:
+   - q_ptr: shape   (bs, d_heads, seq_q)
+   - k_ptr: shape   (bs, d_heads, seq_k)
+   - v_ptr: shape   (bs, seq_v, d_heads)
+   - out_ptr: shape (bs, d_heads, seq_q)
+   - We use seq_q and seq_k just for clarity; this kernel requires seq_q == seq_k
+
+  IO tensor dtypes:
+   - This kernel assumes all IO tensors have the same dtype
+   - If mixed_precision is True, then all Tensor Engine operations will be performed in
+     bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
+     will be in the same type as the inputs.
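+
+  Example usage (shapes are illustrative only; they must satisfy the constraints above):
+    q, k: [bs, 128, 4096], v: [bs, 4096, 128]
+      usage: `out = allocated_fused_self_attn_for_SD_small_head_size[bs](q, k, v)`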
+  """
+  # Use float32 as the intermediate tensor dtype
+  # Assume all IO tensors have the same dtype
+  kernel_dtype = np.float32
+  pe_in_dt = nl.bfloat16 if mixed_precision else kernel_dtype
+
+  kernel_dtype_itemsize = np.dtype(kernel_dtype).itemsize
+  pe_in_dt_itemsize = np.dtype(pe_in_dt).itemsize
+  assert q_ref.dtype == k_ref.dtype == v_ref.dtype
+
+  # Shape checking
+  bs, d_head, seqlen = q_ref.shape
+  assert d_head <= 128, "Cannot use this kernel for d_head > 128"
+  assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!'
+  assert tuple(k_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!'
+  assert tuple(v_ref.shape) == (bs, seqlen,
+                                d_head), f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
+  out_ref = nl.ndarray((bs, d_head, seqlen), dtype=q_ref.dtype, buffer=nl.shared_hbm)
+
+  assert d_head == 128
+
+  cur_addr = 0
+
+  id0 = nl.arange(0, 128)[:, None]
+  id1 = nl.arange(0, 128)[None, :]
+  identity = nl.shared_constant(np.identity(128, dtype=np.int8), dtype=nl.bfloat16)
+  identity_load = nl.ndarray((par_dim(128), 128), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr))
+  cur_addr += 128 * pe_in_dt_itemsize
+  identity_load[id0, id1] = nl.load(identity)
+
+  identity_load_fp32 = nl.ndarray((par_dim(128), 128), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr))
+  cur_addr += 128 * np.dtype(np.float32).itemsize
+  identity_load_fp32[id0, id1] = nl.load(identity)
+
+  # Softmax scaling factor, multiplied onto Q
+  softmax_scale = 0.125
+
+  # Different batch samples/attention heads have independent attention
+  batch_id = nl.program_id(axis=0)
+
+  q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128
+  k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512
+  # No tiling on the d_head dimension since d_head fits in SBUF
+  d_head_tile_size = d_head
+  v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128
+
+  ###################################
+  # Step 1. preload tensors
+  ###################################
+  v_local = nl.ndarray((v_seq_n_tiles, par_dim(v_seq_tile_size), d_head), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(v_seq_n_tiles, ))) # 8kb
+  cur_addr += v_seq_n_tiles * d_head * pe_in_dt_itemsize
+
+  for i_v_seq_tile in nl.affine_range(v_seq_n_tiles):
+    ip_v = nl.arange(v_seq_tile_size)[:, None]
+    if_v = nl.arange(d_head_tile_size)[None, :]
+    v_local[i_v_seq_tile, ip_v, if_v] = nl.load(
+      v_ref[batch_id, i_v_seq_tile * v_seq_tile_size + ip_v, if_v],
+      dtype=pe_in_dt)
+
+  q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(q_seq_n_tiles, ))) # 8kb
+  cur_addr += q_seq_n_tiles * q_seq_tile_size * pe_in_dt_itemsize
+  ip_q = nl.arange(d_head_tile_size)[:, None]
+  if_q = nl.arange(q_seq_tile_size)[None, :]
+  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
+    q_local[i_q_seq_tile, ip_q, if_q] = nl.load(
+      q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q],
+      dtype=pe_in_dt)
+    q_local[i_q_seq_tile, ip_q, if_q] = nl.multiply(q_local[i_q_seq_tile, ip_q, if_q], softmax_scale)
+
+  k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(k_seq_n_tiles, ))) # 8kb
+  cur_addr += k_seq_n_tiles * k_seq_tile_size * pe_in_dt_itemsize
+  ip_k = nl.arange(d_head_tile_size)[:, None]
+  if_k = nl.arange(k_seq_tile_size)[None, :]
+  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
+    k_local[i_k_seq_tile, ip_k, if_k] = nl.load(
+      k_ref[batch_id,
+            ip_k,
+            i_k_seq_tile * k_seq_tile_size + if_k
+            ],
+      dtype=pe_in_dt)
+
+  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles//2):  # indent = 2
+    # Perform the softmax activation and reduction in a larger tile to amortize instruction overhead
+    reduction_size = 1024
+    reduction_tiles = seqlen // reduction_size
+
+    # =================================== SBUF Allocation Starts ===================================
+
+    # The num_free_tiles is intentionally set to (1, ) to disable double buffering on the first matmul.
+    # From the profile, when the first matmul is double buffered, the tensor_scalar_reduce instruction that writes to this buffer
+    # spends a long time waiting for the matmul it depends on to execute. The instruction scheduler makes a bad decision and
+    # clogs the pipeline when double buffering is on. This is a workaround to hint the scheduler.
+    qk_res_buf = nl.ndarray((2, par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(1, ))) # 32 k
+    cur_addr += seqlen * kernel_dtype_itemsize
+    exp_res = nl.ndarray((2, par_dim(q_seq_tile_size), seqlen),dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 16 kb
+    cur_addr += seqlen * 2 * pe_in_dt_itemsize
+    trans_softmax_res = nl.ndarray(
+        (2, par_dim(v_seq_tile_size), seqlen), name='trans_softmax_res',
+        dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 16kb
+    cur_addr += seqlen * 2 * pe_in_dt_itemsize
+    
+    sum_divisor = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 1kb
+    cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize
+    sum_reciprocal_broadcast = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 1kb
+    cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize
+    
+    attn_res_sbuf = nl.ndarray((2, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype,
+                                buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, )), name="attn_res_sbuf") # 1kb
+    cur_addr += 2 * q_seq_tile_size * kernel_dtype_itemsize
+    attn_res_div = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype,
+                                buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2,))) # 1kb
+    cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize
+    
+    neg_max_res = nl.ndarray((2, par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 64b
+    cur_addr += 2 * k_seq_n_tiles * kernel_dtype_itemsize
+    partial_sum_res = nl.ndarray((2, par_dim(q_seq_tile_size), reduction_tiles), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 32b
+    cur_addr += 2 * reduction_tiles * kernel_dtype_itemsize
+    neg_max_res_final = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b
+    cur_addr += 2 * 1 * kernel_dtype_itemsize
+    sum_res = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b
+    cur_addr += 2 * 1 * kernel_dtype_itemsize
+    sum_reciprocal = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b
+    cur_addr += 2 * 1 * kernel_dtype_itemsize
+
+    # =================================== SBUF Allocation End ===================================
+ 
+    qk_psum = nl.ndarray((2, k_seq_n_tiles, par_dim(q_seq_tile_size), k_seq_tile_size),
+                          dtype=np.float32, buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(2, 4)))
+    
+    assert k_seq_tile_size == 4 * v_seq_tile_size
+    local_tp_buf = nl.ndarray((2, k_seq_n_tiles, par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32,
+                                  buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(2, 4)))
+    
+    def psum_addr(bank_map, idx, pdim_size, fdim_size):
+      return (bank_map[idx], 0, 0)
+    
+    # The result psum buffer has the hidden dim as the partition dim (P).
+    # qk_psum uses banks 0, 1, 2, 3 for the first interleave group, and 4, 5, 6, 7 for the second.
+    # Assigning banks 1 and 5 avoids bank collisions between groups.
+    attn_res_psum = nl.ndarray((2, par_dim(d_head_tile_size), q_seq_tile_size),
+                            dtype=np.float32, buffer=ncc.psum.alloc(functools.partial(psum_addr, bank_map={(0, ): 1, (1, ): 5})))
+
+    sum_local_tp_buf = nl.ndarray((2, par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32,
+                                  buffer=ncc.psum.alloc(functools.partial(psum_addr, bank_map={(0, ): 2, (1, ): 7})))
+
+    for i_interleave_grp in nl.affine_range(2):
+      # An SBUF buffer tile for an independent softmax tile
+      ip_max = nl.arange(q_seq_tile_size)[:, None]
+      if_max = nl.arange(k_seq_n_tiles)[None, :]
+
+      # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
+      for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):  # indent = 4
+
+        # Tensor indices for accessing qk result in k_seq_tile_size
+        ip_qk = nl.arange(q_seq_tile_size)[:, None]
+        if_qk = nl.arange(k_seq_tile_size)[None, :]
+
+        ##############################################################
+        # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
+        ##############################################################
+        qk_psum[i_interleave_grp, i_k_seq_tile, ip_qk, if_qk] = nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k],
+                                                stationary=q_local[i_q_seq_tile*2+i_interleave_grp, ip_q, if_q])
+
+        ###################################
+        # Step 3. Apply optional causal mask
+        ###################################
+        if use_causal_mask:
+          assert not use_causal_mask, "Causal mask not supported yet!"
+          # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+          qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select(
+            pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk),
+            on_true_tile=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype)
+        else:
+          # Copy result to SBUF and find partial maximum for softmax
+          qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.tensor_scalar_reduce(data=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], op0=np.add, operand0=1.0,
+              reduce_op=np.max, reduce_res=neg_max_res[i_interleave_grp, ip_max, i_k_seq_tile], dtype=kernel_dtype)
+
+      # Find global max from tiles
+      neg_max_res_final[i_interleave_grp, ip_max, 0] = nisa.tensor_reduce(
+        np.max, data=neg_max_res[i_interleave_grp, ip_max, if_max],
+        axis=(1,), dtype=kernel_dtype, negate=True)
+
+      ip_softmax = nl.arange(q_seq_tile_size)[:, None]
+      if_softmax = nl.arange(seqlen)[None, :]
+      ip_sum_res = nl.arange(q_seq_tile_size)[:, None]
+      if_sum_res = nl.arange(d_head_tile_size)[None, :]
+
+      if_reduction = nl.arange(reduction_size)[None, :]
+      for i_exp in nl.affine_range(reduction_tiles):
+        exp_res[i_interleave_grp, ip_softmax, i_exp*reduction_size + if_reduction] = nisa.activation_reduce(np.exp,
+          data=qk_res_buf[i_interleave_grp, ip_softmax, i_exp * reduction_size + if_reduction],
+          reduce_op=np.sum, reduce_res=partial_sum_res[i_interleave_grp, ip_softmax, i_exp],
+          bias=neg_max_res_final[i_interleave_grp, ip_max, 0], scale=1.0,                                                                                          
+        )
+
+      sum_res[i_interleave_grp, ip_softmax, 0] = nisa.tensor_reduce(np.add, data=partial_sum_res[i_interleave_grp, :, :], axis=(1,),
+                            dtype=kernel_dtype)
+      
+      sum_reciprocal[i_interleave_grp, ip_softmax, 0] = nl.divide(1.0, sum_res[i_interleave_grp, ip_softmax, 0])
+      sum_reciprocal_broadcast[i_interleave_grp, ip_softmax, if_sum_res] = sum_reciprocal[i_interleave_grp, ip_softmax, 0].broadcast_to((q_seq_tile_size, d_head_tile_size))
+      sum_divisor[i_interleave_grp, ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast[i_interleave_grp, ip_softmax, if_sum_res], dtype=kernel_dtype)
+
+      ###################################
+      # Step 5. transpose(softmax_res)
+      ###################################
+      ip_scores_t = nl.arange(v_seq_tile_size)[:, None]
+      if_scores_t = nl.arange(v_seq_tile_size)[None, :]
+      # Loop over matmul_1 contraction
+      for i_v_seq_tile in nl.affine_range(v_seq_n_tiles // 4):
+        for i_offset in nl.affine_range(4):
+          ip_scores = nl.arange(v_seq_tile_size)[:, None]
+          if_scores = nl.arange(v_seq_tile_size)[None, :]
+          
+          local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, i_offset*v_seq_tile_size + if_scores] = nisa.nc_matmul(
+            exp_res[i_interleave_grp, ip_scores, (i_v_seq_tile*4+i_offset) * v_seq_tile_size + if_scores],
+            identity_load)
+
+        if_batch = nl.arange(k_seq_tile_size)[None, :]
+        trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*k_seq_tile_size + if_batch] = nl.copy(local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, if_batch])
+
+      ip_out = nl.arange(d_head_tile_size)[:, None]
+      if_out = nl.arange(q_seq_tile_size)[None, :]
+
+      for i_v_seq_tile in nl.affine_range(v_seq_n_tiles):
+        ######################################################################
+        # Step 6. matmul_1(stationary=v_local, moving=trans_softmax_res, contract=seqlen_v=seqlen_k)
+        ######################################################################
+        ip_v_t = nl.arange(v_seq_tile_size)[:, None]
+        if_v_t = nl.arange(d_head_tile_size)[None, :]
+        attn_res_psum[i_interleave_grp, ip_out, if_out] += \
+          nisa.nc_matmul(moving=trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*v_seq_tile_size+if_scores_t],
+                        stationary=v_local[i_v_seq_tile, ip_v_t, if_v_t])
+      
+      attn_res_sbuf[i_interleave_grp, ip_out, if_out] = nisa.tensor_copy(attn_res_psum[i_interleave_grp, ip_out, if_out], 
+                                                                    dtype=kernel_dtype, engine=nisa.vector_engine)
+
+      sum_local_tp_buf[i_interleave_grp, ip_sum_res, if_sum_res] = nisa.nc_matmul(sum_divisor[i_interleave_grp, ip_sum_res, if_sum_res], identity_load_fp32)
+      attn_res_div[i_interleave_grp, ip_sum_res, if_sum_res] = attn_res_sbuf[i_interleave_grp, :, :] * sum_local_tp_buf[i_interleave_grp, ip_sum_res, if_sum_res]
+
+      nl.store(
+        out_ref[batch_id, ip_out, (i_q_seq_tile*2+i_interleave_grp) * q_seq_tile_size + if_out],
+        value=attn_res_div[i_interleave_grp, :, :])
+      
+  return out_ref
\ No newline at end of file
diff --git a/src/nki_samples/reference/allocated_fused_linear.py b/src/nki_samples/reference/allocated_fused_linear.py
new file mode 100644
index 0000000..21e32af
--- /dev/null
+++ b/src/nki_samples/reference/allocated_fused_linear.py
@@ -0,0 +1,114 @@
+"""
+Copyright (c) 2024, Amazon.com. All Rights Reserved
+
+kernels - Fused normalization with linear layers
+
+"""
+
+import neuronxcc.nki.language as nl
+import neuronxcc.nki.isa as nisa
+import neuronxcc.nki.compiler as ncc
+import math
+import numpy as np
+from neuronxcc import nki
+from neuronxcc.nki.language import par_dim
+
+@nki.jit
+def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-6):
+  """
+  Allocated kernel that computes RMSNorm(hidden) @ wQKV. This kernel is designed to only handle fp16/bf16 tensor types.
+  Internally, normalizations are cast to fp32 to avoid NaN errors.
+
+  Args:
+      hidden (_type_): Input tensor of the attention block in BSH layout
+      weights (_type_): Fused QKV linear weights, assumed to be eltwise-multiplied with the RMS norm weight vector (gamma)
+      norm_dtype (_type_, optional): Data type for RMS norm, should be f32 to avoid NaN. Defaults to nl.float32.
+      eps (_type_, optional): RMS norm epsilon term. Defaults to 1e-6.
+
+  Returns:
+      Output tensor of shape (batch, seqlen, head_dim) in BSH layout, allocated in shared HBM.
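+
+  Example usage (illustrative shapes only; dim and head_dim must satisfy the asserts below):
+    hidden: [batch, seqlen, 8192] in fp16/bf16, weights: [8192, 512] in the same dtype
+      usage: ``out = allocated_fused_rms_norm_qkv(hidden, weights)``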
+  """
+  # Hidden should be in BSH layout.
+  batch, batchless_shape = hidden.shape[0], hidden.shape[1:]
+  seqlen, dim = batchless_shape
+  _dim, head_dim = weights.shape
+
+  assert dim <= 8192 and dim % 128 == 0, "Unsupported hidden dimension"
+  assert _dim == dim, "Reduction dimension must match"
+  assert head_dim <= 512, "Head dimension must be 512 or less"
+
+  out_tensor = nl.ndarray((batch, seqlen, head_dim), dtype=hidden.dtype, buffer=nl.shared_hbm)
+
+  pmax, fmax = nl.tile_size.pmax, nl.tile_size.psum_fmax # 128, 512
+  ix, iy = nl.mgrid[0:pmax, 0:dim]
+  i_lhs = nl.mgrid[0:pmax, 0:pmax]
+  i_rhs = nl.mgrid[0:pmax, 0:fmax]
+  i_res = nl.mgrid[0:pmax, 0:fmax]
+  M = math.ceil(dim / pmax)
+  NUM_TRANSP_TILES = math.ceil(dim / fmax)
+  NUM_TILES = math.ceil(seqlen / pmax)
+  TILES_INT = math.ceil(NUM_TILES / 2)
+  scale = 1 / dim
+
+  iden_x, iden_y = nl.mgrid[0:pmax, 0:128]
+
+  identity_a = nl.shared_constant(np.identity(n=128, dtype=np.int8), dtype=hidden.dtype)
+  identity_tensor = nl.ndarray((par_dim(pmax), 128), dtype=weights.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=0))
+  identity_tensor[iden_x, iden_y] = nl.load(identity_a, dtype=weights.dtype)
+  bias_placeholder = nl.ndarray((par_dim(pmax), 1), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=128*2))
+  bias_placeholder[...] = 0
+  
+  for b in nl.affine_range(batch):
+    weights_buffer = nl.ndarray((M, par_dim(pmax), fmax), dtype=weights.dtype,
+                                buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim+fmax)*2+(dim+1)*4, num_free_tiles=(M,)))
+    # Preload the entire weights tensor; everything fits in SBUF for LLaMA 3.1 70B
+    for m in nl.affine_range(M):
+      weights_buffer[m, i_rhs.p, i_rhs.x] = nl.load(weights[m*pmax+i_rhs.p, i_rhs.x],
+                                                    mask=(m*pmax+i_rhs.p<dim) & (i_rhs.x<head_dim))
+    for i in nl.affine_range(TILES_INT):
+      # Double buffer the input tensor
+      in_bufs = nl.ndarray((2, par_dim(pmax), dim), dtype=hidden.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260, num_free_tiles=(2,)))
+      for i_interleave_grp in nl.affine_range(2):
+        in_bufs[i_interleave_grp] = nl.load(hidden[b, (2*i+i_interleave_grp)*pmax+ix, iy], mask=(2*i+i_interleave_grp)*pmax+ix < seqlen)
+        act = nl.ndarray((par_dim(pmax), dim), dtype=norm_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2))
+
+        # Write the RMS and RMS Reciprocal tensors back out here, in-place
+        square_sum = nl.ndarray((par_dim(pmax), 1), dtype=norm_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2+(dim)*4))
+
+        # Write the output of RMS and RMS^T (in-place) out to here
+        out_tile = nl.ndarray((par_dim(pmax), dim), dtype=weights.dtype,
+                              buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2+(dim+1)*4))
+        
+        # Store the final output tiles to here before sending back to DRAM
+        output_sbuf = nl.ndarray((par_dim(pmax), fmax), dtype=weights.dtype,
+                                buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim)*2+(dim+1)*4))
+
+        act[...] = nisa.activation_reduce(op=nl.square, data=in_bufs[i_interleave_grp], reduce_op=np.add, reduce_res=square_sum[...], bias=bias_placeholder[...])
+        square_sum[...] = nisa.tensor_scalar(square_sum[...], np.multiply, scale, op1=np.add, operand1=eps)
+        square_sum[...] = nisa.activation(op=nl.rsqrt, data=square_sum[...], bias=bias_placeholder[...])
+
+        # All PE array ops must output FP32 on trn1 but must match the input dtype on trn2
+        if nisa.get_nc_version() == nisa.nc_version.gen3:
+          transpose_res_psum = nl.ndarray((NUM_TRANSP_TILES, par_dim(pmax), 4*pmax), dtype=weights.dtype,
+                                          buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(1,))) # FIXME: perf is better when all tiles are on bank 0?
+        else:
+          transpose_res_psum = nl.ndarray((NUM_TRANSP_TILES, par_dim(pmax), 4*pmax), dtype=np.float32,
+                                          buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(1,))) # FIXME: perf is better when all tiles are on bank 0?
+
+        for m in nl.affine_range(NUM_TRANSP_TILES):
+          # Perform (hidden .* RMS Reciprocal)^T in tiles of fmax (512)
+          out_tile[i_rhs.p, m*fmax+i_rhs.x] = nl.multiply(in_bufs[i_interleave_grp, i_rhs.p, m*fmax + i_rhs.x], square_sum[...], dtype=weights.dtype)
+          for j in nl.affine_range(4):
+            transpose_res_psum[m, i_lhs.p, j*pmax+i_lhs.x] = nisa.nc_matmul(out_tile[i_lhs.p, (m*4+j) * pmax + i_lhs.x], identity_tensor[...],
+                                                                            is_transpose=True)
+          out_tile[i_rhs.p, m * 4*pmax + i_rhs.x] = nl.copy(transpose_res_psum[m], dtype=hidden.dtype)
+        
+        # perform (RMSNorm(hidden)^T)^T @ wQKV
+        res_psum = nl.ndarray((1, par_dim(pmax), fmax), dtype=nl.float32,
+                              buffer=ncc.psum.mod_alloc(base_bank=7, num_bank_tiles=(1,)))
+        for m in nl.affine_range(M):
+          res_psum[0] += nisa.nc_matmul(out_tile[i_lhs.p, m*pmax+i_lhs.x], weights_buffer[m, i_rhs.p, i_rhs.x])
+        
+        output_sbuf[...] = nl.copy(res_psum[0], dtype=out_tensor.dtype)
+        nl.store(out_tensor[b, (2*i+i_interleave_grp)*pmax+i_res.p, i_res.x],
+                value=output_sbuf,
+                mask=((2*i+i_interleave_grp)*pmax+i_res.p<seqlen) & (i_res.x<head_dim))
+  return out_tensor
\ No newline at end of file
diff --git a/src/nki_samples/reference/attention.py b/src/nki_samples/reference/attention.py
new file mode 100644
index 0000000..3c456a6
--- /dev/null
+++ b/src/nki_samples/reference/attention.py
@@ -0,0 +1,1170 @@
+"""
+Copyright (c) 2023, Amazon.com. All Rights Reserved
+
+kernels - Builtin high performance attention kernels
+
+"""
+import numpy as np
+
+import neuronxcc.nki.isa as nisa
+import neuronxcc.nki.language as nl
+from neuronxcc import nki
+
+from neuronxcc.nki.language import par_dim
+from dataclasses import dataclass
+from functools import reduce as functools_reduce
+from operator import mul as operator_mul
+
+
+def n_elts(shape):
+  return functools_reduce(operator_mul, shape, 1)
+
+
+def linearize(shape, indices):
+  return sum(i * (n_elts(shape[dim + 1:]))
+             for dim, i in enumerate(indices))
+
+
+def div_ceil(n, d):
+  return (n + d - 1) // d
+
+
+@dataclass(frozen=True)
+class FlashConfig:
+  """
+    Config class for flash attention with default values
+  """
+  seq_tile_size:int = 2048
+  attn_core_tile_size:int = 256
+  training:bool = True
+  should_transpose_v:bool = False
+  lse_dtype: str = ""
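+
+# Example (illustrative only, not part of the original kernel): a non-default
+# configuration for inference could be constructed as
+#   config = FlashConfig(seq_tile_size=1024, training=False)
+# and passed to ``flash_fwd`` through its ``config`` argument.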
+
+
+@nki.jit(mode='trace')
+def transpose_p_local(p_local_transposed, p_local, LARGE_TILE_SZ):
+  for i in nl.affine_range(LARGE_TILE_SZ // 512):
+    if nisa.get_nc_version() == nisa.nc_version.gen3:
+      p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.sbuf, dtype=p_local.dtype)
+    else:
+      p_local_t_tmp = nl.ndarray((par_dim(128), 512), buffer=nl.psum, dtype=np.float32)
+
+    for j in nl.affine_range(512 // 128):
+      j_128_slice = nl.ds(j * 128, 128)
+      i_j_128_slice = nl.ds(i * 512 + j * 128, 128)
+
+      if nisa.get_nc_version() == nisa.nc_version.gen3:
+        p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
+          p_local[:, i_j_128_slice])
+      else:
+        p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
+          p_local[:, i_j_128_slice])
+
+    p_local_transposed[:, nl.ds(i * 512, 512)] = nl.copy(
+      p_local_t_tmp, dtype=p_local_transposed.dtype)
+
+
+@nki.jit(mode='trace')
+def dropout_p_local(p_local, dropout_p, dropout_p_tensor, seed_tensor,
+                    seed_offset_base, k_r_i, REDUCTION_TILE):
+  B_F_SIZE = 512
+  for k_d_i in nl.sequential_range(REDUCTION_TILE // B_F_SIZE):
+    p_local_f_slice = nl.ds(k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE, B_F_SIZE)
+
+    offset = k_d_i + seed_offset_base
+    offset_seed = nl.add(seed_tensor, offset, dtype=nl.int32)
+    nl.random_seed(seed=offset_seed)
+    softmax_dropout = nl.dropout(p_local[:, p_local_f_slice],
+                                 rate=dropout_p_tensor[:, 0])
+    p_local[:, p_local_f_slice] = nl.multiply(
+      softmax_dropout, 1 / (1 - dropout_p))
+
+
+@nki.jit(mode='trace')
+def _flash_attention_core(q_local_tile, k, v,
+                          q_h_per_k_h, seqlen_q, nheads,
+                          o_buffer, l_buffer, m_buffer,
+                          batch_id, head_id, gqa_head_idx, q_tile_idx,
+                          local_k_large_tile_idx,
+                          kernel_dtype, acc_type,
+                          flash_config: FlashConfig,
+                          use_causal_mask, initialize,
+                          B_P_SIZE=128, B_F_SIZE=512, B_D_SIZE=128,
+                          dropout_p=0.0, dropout_p_tensor=None, seed_tensor=None,
+                          logit_bias_tile=None):
+  """
+  The flash attention core function calculates self attention between a tile of q and a block of K and V.
+  The q_local_tile has shape (B_D_SIZE, B_P_SIZE) and is already loaded into SBUF. The block size of K and V
+  is defined by the seq_tile_size of the flash_config. The results are stored in the following three buffers:
+  o_buffer: (B_P_SIZE, d)
+  l_buffer: (B_P_SIZE, 1)
+  m_buffer: (B_P_SIZE, 1)
+  """
+  LARGE_TILE_SZ = flash_config.seq_tile_size
+  num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
+  seqlen_k = k.shape[-1]
+  seq_q_num_tiles = seqlen_q // B_P_SIZE
+  seq_k_num_tiles = seqlen_k // B_F_SIZE
+
+  qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), buffer=nl.sbuf, dtype=acc_type)
+  max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), dtype=acc_type)
+
+  for k_i in nl.affine_range(num_k_tile_per_large_tile):
+    k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
+
+    qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
+                        dtype=np.float32, buffer=nl.psum)  # (128, 512)
+    if use_causal_mask:
+      multiplication_required_selection = q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE
+    else:
+      multiplication_required_selection = True
+
+    if multiplication_required_selection:
+      qk_psum[:, :] = nl.matmul(q_local_tile, k[:, k_i_b_f_slice], transpose_x=True) # (p(128), 512)
+    else:
+      qk_psum[:, :] = 0
+
+    if use_causal_mask:
+      left_diagonal_selection = q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE
+      diagonal_and_right_selection = (q_tile_idx * B_P_SIZE < local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE)
+      right_diagonal_selection = ((q_tile_idx + 1) * B_P_SIZE <= local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE)
+      diagonal = ((q_tile_idx * B_P_SIZE < local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) &
+                  ((q_tile_idx + 1) * B_P_SIZE > local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE))
+
+      i_q_p, i_q_f = nl.mgrid[0:B_P_SIZE, 0:B_F_SIZE]
+      q_pos = q_tile_idx * B_P_SIZE + i_q_p
+      k_pos = local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + i_q_f
+      pred = q_pos >= k_pos
+
+      qk_select_tmp = nl.ndarray(qk_psum.shape, dtype=qk_psum.dtype, buffer=nl.sbuf)
+
+      if logit_bias_tile is not None:
+        if right_diagonal_selection:
+          qk_select_tmp[...] = qk_psum
+
+          # For tiles to the right of the diagonal, do affine_select.
+          # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+          qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select(
+              pred=pred,
+              on_true_tile=qk_select_tmp, on_false_value=-9984.0, dtype=acc_type)
+
+        # For tiles on the diagonal, add logit bias and do affine_select.
+        intermediate = \
+            nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice],
+                   dtype=acc_type, mask=diagonal)
+        qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select(
+            pred=pred,
+            on_true_tile=intermediate, on_false_value=-9984.0, dtype=acc_type,
+            mask=diagonal)
+
+        # For tiles on the left of the diagonal, just add logit bias, no select required.
+        qk_res_buf[:, k_i_b_f_slice] = \
+            nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice],
+                   dtype=acc_type, mask=left_diagonal_selection)
+      else:
+        # For tiles on and to the right of the diagonal, need to do affine_select.
+        # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+        if diagonal_and_right_selection:
+          qk_select_tmp[...] = qk_psum
+
+          qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select(
+            pred=pred,
+            on_true_tile=qk_select_tmp, on_false_value=-9984.0, dtype=acc_type)
+
+        # For tiles on the left of the diagonal, direct copy, no select required.
+        qk_res_buf[:, k_i_b_f_slice] = \
+          nl.copy(qk_psum, dtype=acc_type, mask=left_diagonal_selection)
+    else:
+      if logit_bias_tile is not None:
+        # Simply add logit bias which copies back to sbuf at the same time
+        qk_res_buf[:, k_i_b_f_slice] = \
+            nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice], dtype=acc_type)
+      else:
+        # Simply send psum result back to sbuf
+        qk_res_buf[:, k_i_b_f_slice] = nl.copy(qk_psum, dtype=acc_type)
+
+    # Calculate max of the current tile
+    max_local[:, k_i] = nisa.tensor_reduce(
+      np.max, qk_res_buf[:, k_i_b_f_slice], axis=(1,), dtype=acc_type,
+      negate=False)
+
+  max_ = nisa.tensor_reduce(np.max, max_local[:, :], axis=(1, ),
+                            dtype=acc_type, negate=False)
+
+  o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), dtype=o_buffer.dtype)
+
+  if initialize:
+    m_buffer[:, 0] = nl.copy(max_)
+    m_current = max_
+  else:
+    m_previous = nl.copy(m_buffer[:, 0])
+    m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
+
+    m_current = m_buffer[:, 0]
+    # Compute scaling factor
+    alpha = nisa.activation(np.exp, m_current, bias=m_previous, scale=-1.0)
+    o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
+
+  p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
+  REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
+
+  p_partial_sum = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type)
+
+  for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
+    k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
+
+    # dropout
+    if dropout_p > 0.0:
+      # compute exp(qk-max)
+      p_local[:, k_r_i_reduce_slice] = \
+        nisa.activation(np.exp, qk_res_buf[:, k_r_i_reduce_slice],
+                        bias=-1 * m_current, scale=1.0,
+                        dtype=kernel_dtype)
+
+      seed_offset_base = k_r_i * (REDUCTION_TILE // B_F_SIZE) \
+                         + local_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \
+                         + q_tile_idx * seq_k_num_tiles \
+                         + (head_id * q_h_per_k_h + gqa_head_idx) * seq_k_num_tiles * seq_q_num_tiles \
+                         + batch_id * nheads * seq_k_num_tiles * seq_q_num_tiles
+
+      dropout_p_local(p_local=p_local, dropout_p=dropout_p,
+                      dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_tensor,
+                      seed_offset_base=seed_offset_base, k_r_i=k_r_i,
+                      REDUCTION_TILE=REDUCTION_TILE)
+
+      # Compute partial row-tile sum of exp(qk - max)
+      # FIXME: Use activation accumulate and accumulate over k_r_i loop?
+      p_partial_sum[:, k_r_i] = nl.sum(p_local[:, k_r_i_reduce_slice],
+                                       axis=1, dtype=acc_type)
+    else:
+      # compute exp(qk-max)
+      # Compute partial row-tile sum of exp(qk - max)
+      # FIXME: Use activation accumulate to accumulate over k_r_i loop?
+      p_local[:, k_r_i_reduce_slice] = \
+        nisa.activation_reduce(np.exp, qk_res_buf[:, k_r_i_reduce_slice],
+                               bias=-1 * m_current, scale=1.0,
+                               reduce_op=nl.add, reduce_res=p_partial_sum[:, k_r_i],
+                               dtype=kernel_dtype)
+
+  ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
+
+  p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
+  transpose_p_local(p_local_transposed=p_local_transposed, p_local=p_local,
+                    LARGE_TILE_SZ=LARGE_TILE_SZ)
+
+  pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32,
+                     buffer=nl.psum, lazy_initialization=True)
+  for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
+    pv_psum[:, :] += nl.matmul(p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
+                               v[k_i, :, :], transpose_x=True) # (128, 128) (p(Br), d)
+
+  if initialize:
+    o_buffer[:, :] = nl.copy(pv_psum[:, :])
+    l_buffer[:, 0] = nl.add(nl.log(ps), max_)
+  else:
+    o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
+
+    exp = nisa.activation(nl.exp, m_current, bias=l_buffer[:, 0], scale=-1.0)
+    l_buffer[:, 0] = nl.add(m_current, nisa.activation(nl.log, exp, bias=ps))
+
+
+@nki.jit(mode='trace')
+def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
+  LARGE_TILE_SZ = config.seq_tile_size
+  B_P_SIZE = 128
+
+  if not config.should_transpose_v:
+    cur_v_tile[v_i, :, :] = nl.load(
+      v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :],
+      dtype=cur_v_tile.dtype)
+    return
+
+  if nisa.get_nc_version() == nisa.nc_version.gen3:
+    cur_v_tile_transposed = nisa.dma_transpose(
+      v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)])
+    cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed,
+                                             dtype=cur_v_tile.dtype)
+    return
+
+  cur_v_tile[v_i, :, :] = nl.load_transpose2d(
+    v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)],
+    dtype=cur_v_tile.dtype)
+
+
+
+@nki.jit
+def flash_fwd(q, k, v, seed, logit_bias=None,
+              softmax_scale=None,
+              use_causal_mask=True,
+              mixed_precision=True,
+              dropout_p=0.0, config=None):
+  """
+  Flash Attention Forward kernel
+
+  IO tensor layouts:
+    - q: shape   (bs, n_heads, d, seq_q)
+    - k: shape   (bs, nk_heads, d, seq_k)
+    - v: shape   (bs, nv_heads, d, seq_v) if config.should_transpose_v  else (bs, nv_heads, seq_v, d)
+    - seed: shape (1,)
+    - logit_bias: shape (bs, n_heads, seq_q, seq_k)
+    - o: shape (bs, n_heads, seq_q, d)
+    - lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None
+    - This kernel requires seq_k == seq_v
+
+  IO tensor dtypes:
+    - This kernel assumes all IO tensors have the same dtype
+    - If mixed_precision is True, then all Tensor Engine operations will be performed in
+      bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
+      will be in the same type as the inputs.
+
+  Compile-time Constants:
+    - softmax_scale: scaling for softmax; if None, defaults to `1.0/(d**0.5)`
+    - mixed_precision: flag to run non-matmul ops in fp32 precision; defaults to `True`. If False, the same precision as the input types is used
+    - use_causal_mask: flag to enable causal masking
+    - config: Instance of :class:`nki.kernels.attention.FlashConfig` with performance config parameters for flash attention, with default values:
+        seq_tile_size: `default=2048`, size of the kv tile for the attention computation reduction
+        training: bool to indicate training vs inference, `default=True`
+
+  Performance Notes:
+    For better performance, the kernel is tiled with tile size `config.seq_tile_size`, and the Flash attention math techniques are applied in units
+    of `config.seq_tile_size`. Seqlen values that are not divisible by `config.seq_tile_size` are not supported at the moment.
+
+    For large seqlen, `o_buffer` would overflow the state buffer, so the kernel tiles `o_buffer` based on the value of `config.attn_core_tile_size`.
+    This is a tradeoff between memory usage and performance. The default value of `config.attn_core_tile_size` is 256, which means `o_buffer`
+    will roughly take half of the state buffer. The computation is tiled accordingly. DMA will be rematerialized
+    `seqlen_q // B_P_SIZE // attn_core_tile_size` times.
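+
+    For example (illustrative numbers only): with seqlen_q = 65536 and B_P_SIZE = 128 there are
+    512 query tiles, so with the default `attn_core_tile_size` of 256 the DMA loads are
+    rematerialized 512 // 256 = 2 times.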
+
+
+
+  GQA support Notes:
+    The SPMD grid for launching the kernel should be over kv_heads instead of nheads.
+
+  Example usage:
+    MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
+      usage: `flash_fwd[b, h](q, k, v, ...)`
+    GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
+      usage: `flash_fwd[b, kv_h](q, k, v, ...)`
+  """
+  config = config or FlashConfig()
+  B_F_SIZE=512
+  B_P_SIZE=128
+  b, h, d, seqlen_q  = q.shape
+  B_D_SIZE = d
+  _, k_h, _, seqlen_k = k.shape
+  if config.should_transpose_v:
+    assert tuple(v.shape) == (b, k_h, d, seqlen_k), f"Expect shape of V to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {v.shape}"
+    assert tuple(k.shape) == (b, k_h, d, seqlen_k), f"Expect shape of K to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {k.shape}"
+  else:
+    assert tuple(v.shape) == (b, k_h, seqlen_k, d), f"Expect shape of V to be {(b, k_h, seqlen_k, d)} (batch, heads, seqlen_k, d_head) but got {v.shape}"
+    assert tuple(k.shape) == (b, k_h, d, seqlen_k), f"Expect shape of K to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {k.shape}"
+  assert d <= 128, f" we do not support head_dim > 128, got head dim {d}"
+  kernel_dtype = nl.bfloat16 if mixed_precision else q.dtype
+  acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
+
+  o = nl.ndarray((b, h, seqlen_q, d), dtype=q.dtype, buffer=nl.shared_hbm)
+  if config.training:
+    if config.lse_dtype:
+      lse_dtype = getattr(nl, config.lse_dtype)
+    else:
+      lse_dtype = acc_type
+    lse = nl.ndarray((b, h, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax),
+                     dtype=lse_dtype, buffer=nl.shared_hbm)
+  else:
+    lse = None
+
+  assert nl.program_ndim() == 2,\
+    f'Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!'
+  batch_id = nl.program_id(axis=0)
+  head_id = nl.program_id(axis=1)
+
+  softmax_scale = softmax_scale or (1.0 / (d ** 0.5))
+
+  n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
+
+  LARGE_TILE_SZ = config.seq_tile_size
+  attn_core_tile_size = config.attn_core_tile_size
+
+  # FIXME: Add masking for different seqlen values.
+  assert config.seq_tile_size >= 512, f" seq tile_size {config.seq_tile_size} cannot be less than 512"
+  assert seqlen_k % LARGE_TILE_SZ == 0, f"Need seqlen_k to be divisible by {LARGE_TILE_SZ} but got {seqlen_k}"
+  num_large_k_tile = seqlen_k // LARGE_TILE_SZ
+
+  # inference flag, check if lse is none
+  inference = not config.training
+  if inference:
+    assert lse is None, "lse should be none for inference"
+    assert seed is None, f"seed should be None for inference, but got {seed}"
+    assert dropout_p==0.0, f"dropout should be 0.0 for inference but got {dropout_p}"
+  else:
+    assert lse is not None, "lse should not be none for training"
+  q_h_per_k_h = h // k_h
+
+  if dropout_p > 0.0 and not inference:
+    seed_local = nl.load(seed[0])
+    # TODO: Remove this once the dropout supports scale prob
+    dropout_p_tensor = nl.full((B_P_SIZE, 1), fill_value=dropout_p, dtype=np.float32)
+  else:
+    dropout_p_tensor = None
+    seed_local = None
+
+  if logit_bias is not None:
+    b_logit_bias, h_logit_bias, _, _ = logit_bias.shape
+    assert b_logit_bias == 1 and h_logit_bias == 1, "only support broadcasting logit_bias with batch 1, n_heads 1"
+
+  n_remat = div_ceil(n_tile_q, attn_core_tile_size)
+  attn_core_tile_size = min(n_tile_q, attn_core_tile_size)
+
+  for i_q_h in nl.affine_range(q_h_per_k_h):
+    # =============== Global Flash Attention accumulators ====================== #
+    l_buffer = nl.zeros((par_dim(B_P_SIZE), n_tile_q), dtype=acc_type,
+                        buffer=nl.sbuf, lazy_initialization=True)
+    # =============== Global Flash Attention accumulators END ================== #
+
+    for i0 in nl.sequential_range(n_remat):
+      # =============== Global Flash Attention accumulators ====================== #
+      o_buffer = nl.zeros((attn_core_tile_size, par_dim(B_P_SIZE), d), dtype=acc_type,
+                          buffer=nl.sbuf, lazy_initialization=True)
+      m_buffer = nl.zeros((attn_core_tile_size, par_dim(B_P_SIZE), 1), dtype=acc_type,
+                          buffer=nl.sbuf, lazy_initialization=True)
+      # =============== Global Flash Attention accumulators END ================== #
+
+      for j in nl.sequential_range(0, num_large_k_tile):
+        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
+        cur_v_tile = nl.ndarray((LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype)
+
+        cur_k_tile[:, :] = nl.load(k[batch_id, head_id, :, nl.ds(j*LARGE_TILE_SZ, LARGE_TILE_SZ)])
+
+        load_tile_size = B_P_SIZE
+
+        v_hbm_tile = v[batch_id, head_id]
+        for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
+          load_v_tile(v_hbm_tile=v_hbm_tile, cur_v_tile=cur_v_tile, j=j, v_i=v_i,
+                      config=config)
+
+        for i1 in nl.affine_range(attn_core_tile_size):
+          i = i0 * attn_core_tile_size + i1
+          # Masks are used to apply computation only to the lower half of the matrix,
+          # which reduces the arithmetic intensity by half.
+          # forward_mask implies initialize, i.e. if forward_mask is false, initialize will
+          # be false as well
+          if use_causal_mask:
+            forward_mask = i * B_P_SIZE >= j * LARGE_TILE_SZ
+          else:
+            forward_mask = True
+
+          if (i < n_tile_q) & forward_mask:
+            q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype)
+            q_hbm_tile = q[batch_id, head_id * q_h_per_k_h + i_q_h]
+            q_sbuf_tile = nl.load(q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
+                                  dtype=kernel_dtype) # load (d, 128) tile in SBUF
+            q_tile[:, :] = q_sbuf_tile * softmax_scale
+
+            logit_bias_tile = None
+            if logit_bias is not None:
+              logit_bias_tile = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
+              logit_bias_tile[:, :] = nl.load(
+                logit_bias[0, 0, nl.ds(i * B_P_SIZE, B_P_SIZE),
+                           nl.ds(j * LARGE_TILE_SZ, LARGE_TILE_SZ)])
+
+            _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile,
+                                  q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h,
+                                  o_buffer=o_buffer[i1], l_buffer=l_buffer[:, i], m_buffer=m_buffer[i1],
+                                  batch_id=batch_id, head_id=head_id,
+                                  gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j,
+                                  kernel_dtype=kernel_dtype, acc_type=acc_type,
+                                  flash_config=config, use_causal_mask=use_causal_mask,
+                                  initialize=j == 0,
+                                  B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE,
+                                  dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor,
+                                  seed_tensor=seed_local, logit_bias_tile=logit_bias_tile)
+
+      # -------- write output to buffer on HBM ------------ #
+      for i1 in nl.affine_range(attn_core_tile_size):
+        i = i0 * attn_core_tile_size + i1
+
+        if i < n_tile_q:
+          exp = nisa.activation(np.exp, l_buffer[:, i], bias=m_buffer[i1, :, :],
+                                scale=-1.0)
+          out = nl.multiply(o_buffer[i1, :, :], exp,
+                            dtype=kernel_dtype)
+
+          nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h,
+                     nl.ds(i*B_P_SIZE, B_P_SIZE), :], out)
+
+    if not inference:
+      nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], l_buffer[:, :])
+
+  if config.training:
+    return o, lse
+
+  return o
+
+
+
+@nki.jit
+def flash_attn_bwd(
+  q_ref, k_ref, v_ref, o_ref,
+  dy_ref,
+  lse_ref,
+  seed_ref,
+  logit_bias_ref=None,
+  use_causal_mask=False,
+  mixed_precision=False,
+  dropout_p=0.0,
+  softmax_scale=None,
+):
+  """
+  Flash attention backward kernel. Compute the backward gradients.
+
+  IO tensor layouts:
+   - q_ref: shape (bs, nheads, head_size, seq)
+   - k_ref: shape (bs, nheads, head_size, seq)
+   - v_ref: shape (bs, nheads, head_size, seq)
+   - o_ref: shape (bs, nheads, head_size, seq)
+   - dy_ref: shape (bs, nheads, head_size, seq)
+   - lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax)
+   - seed_ref: shape (1,)
+   - logit_bias_ref: shape (bs, n_heads, seq_q, seq_k)
+   - out_dq_ref: shape (bs, nheads, head_size, seq)
+   - out_dk_ref: shape (bs, nheads, head_size, seq)
+   - out_dv_ref: shape (bs, nheads, head_size, seq)
+
+  Detailed steps:
+    1. D = rowsum(dO ◦ O) (pointwise multiply)
+
+    2. Recompute (softmax(Q^T@K + logit_bias))
+
+      2.1 Q^T@K
+      2.2 Scale the QK score
+      2.3 Apply causal mask and add logit_bias
+      2.4 softmax
+
+    3. Compute the gradients of y = score @ V with respect to the loss
+
+    4. Compute the gradients of y = softmax(x)
+
+    5. Compute the gradients of Q^T@K
+
+      5.1 Compute dQ
+      5.2 Compute dK
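+
+  Example usage (illustrative; mirrors the forward kernel's launch convention):
+    q/k/v/o/dy: [bs, nheads, d, seq], lse: [bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax]
+      usage: `dq, dk, dv = flash_attn_bwd[bs, nheads](q, k, v, o, dy, lse, seed)`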
+  """
+
+  # Use q_ref dtype as the intermediate tensor dtype
+  # Assume all IO tensors have the same dtype
+  kernel_dtype = q_ref.dtype
+  mixed_dtype = np.dtype(np.float32) if mixed_precision else kernel_dtype
+
+  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype
+
+  # Shape checking
+  bs, nheads, d_head, seqlen_q = q_ref.shape
+  _, _, _, seqlen_k = k_ref.shape
+  assert tuple(k_ref.shape) == (bs, nheads, d_head, seqlen_k), \
+    f"Input K shape mismatch, got {k_ref.shape}"
+  assert tuple(v_ref.shape) == (bs, nheads, d_head, seqlen_k), \
+    f"Input V shape mismatch, got {v_ref.shape}"
+  assert tuple(o_ref.shape) == (bs, nheads, d_head, seqlen_q), \
+    f"Input o shape mismatch, got {o_ref.shape}"
+  assert tuple(dy_ref.shape) == (bs, nheads, d_head, seqlen_q), \
+    f"Input dy shape mismatch, got {dy_ref.shape}"
+  assert tuple(lse_ref.shape) == (bs, nheads, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax), \
+    f"Input lse shape mismatch, got {lse_ref.shape}"
+  if seed_ref is not None:
+    assert tuple(seed_ref.shape) == (1,), \
+      f"Input seed shape mismatch, got {seed_ref.shape}"
+
+  out_dq_ref = nl.ndarray((bs, nheads, d_head, seqlen_q), dtype=q_ref.dtype,
+                          buffer=nl.shared_hbm)
+  out_dk_ref = nl.ndarray((bs, nheads, d_head, seqlen_k), dtype=q_ref.dtype,
+                          buffer=nl.shared_hbm)
+  out_dv_ref = nl.ndarray((bs, nheads, d_head, seqlen_k), dtype=q_ref.dtype,
+                          buffer=nl.shared_hbm)
+
+  # FIXME: Add masking for different seqlen values.
+  assert seqlen_q % 128 == 0 and seqlen_k % 128 == 0, \
+    f"Input sequence lengths must be divisible by 128, got seqlen_q == {seqlen_q} and seqlen_k == {seqlen_k}"
+
+  # Softmax scaling factor, multiplied onto Q
+  softmax_scale = softmax_scale or 1.0 / float(d_head ** 0.5)
+
+  assert nl.program_ndim() == 2,\
+    f'Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!'
+  # Different batch samples/attention heads have independent attention
+  batch_id = nl.program_id(axis=0)
+  head_id = nl.program_id(axis=1)
+
+  assert nl.num_programs(1) == nheads, \
+    f"The grid shape mismatch, got {nl.num_programs(1)} but should be {nheads}"
+
+  if logit_bias_ref is not None:
+    b_logit_bias, h_logit_bias, _, _ = logit_bias_ref.shape
+    assert b_logit_bias == 1 and h_logit_bias == 1, "Only support broadcasting logit_bias with batch 1, n_heads 1"
+
+  q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128
+  d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128)
+
+  if seqlen_k >= 512:
+    k_seq_n_tiles, k_seq_tile_size = seqlen_k // 512, 512
+  else:
+    k_seq_n_tiles, k_seq_tile_size = seqlen_k // 128, 128
+
+  k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128
+  k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward
+
+  ##############################################################
+  # Step 2.4 Prefetch exp bias for softmax
+  ##############################################################
+  softmax_exp_bias = nl.zeros((par_dim(q_seq_tile_size), q_seq_n_tiles), dtype=mixed_dtype)
+  lse_local = nl.load(lse_ref[batch_id, head_id, :, :], dtype=mixed_dtype)
+  softmax_exp_bias[:, :] = lse_local * -1.0
+
+  ##############################################################
+  # Step 1 Compute rowsum(dO ◦ O)
+  ##############################################################
+  dy_o_sum = nl.ndarray((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype)
+  compute_rowsum(dy_o_sum=dy_o_sum,
+                 dy_ref_hbm_tile=dy_ref[batch_id, head_id],
+                 o_ref_hbm_tile=o_ref[batch_id, head_id],
+                 d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size,
+                 q_seq_n_tiles=q_seq_n_tiles, q_seq_tile_size=q_seq_tile_size)
+
+  if dropout_p > 0.0:
+    seed_local = nl.load(seed_ref[0])
+    # TODO: Remove this once the dropout supports scale prob
+    dropout_p_local = nl.full((q_seq_tile_size, 1), fill_value=dropout_p, dtype=np.float32)
+  else:
+    seed_local = None
+    dropout_p_local = None
+
+  dq_local_reduced = nl.zeros((q_seq_n_tiles, d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size),
+                              dtype=mixed_dtype)
+
+  # affine_range gives the compiler permission to vectorize instructions
+  # inside the loop, which improves performance. However, when using
+  # dropout we should use sequential_range to avoid vectorizing the
+  # seed setting. TODO: the compiler should avoid vectorizing seed setting
+  _range = nl.sequential_range if dropout_p > 0.0 else nl.affine_range
+
+  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
+    i_k_seq_dslice = nl.ds(i_k_seq_tile * k_seq_tile_size, k_seq_tile_size)
+
+    # Prefetch V, K
+    v_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
+                       dtype=kernel_dtype)
+    k_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
+                       dtype=kernel_dtype)
+    transposed_k_local = nl.zeros((k_seq_fwd_bwd_tile_multipler, d_head_n_tiles,
+                                   par_dim(k_seq_tile_size_backward), d_head_tile_size),
+                                  dtype=kernel_dtype)
+
+    load_kv(k_ref_hbm_tile=k_ref[batch_id, head_id],
+            v_ref_hbm_tile=v_ref[batch_id, head_id],
+            k_local=k_local, transposed_k_local=transposed_k_local, v_local=v_local,
+            d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size,
+            i_k_seq_tile=i_k_seq_tile, k_seq_tile_size=k_seq_tile_size,
+            k_seq_tile_size_backward=k_seq_tile_size_backward)
+
+    # FIXME: Pass sbuf instead, we will have psum spilling in the current implementation
+    dv_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
+                        dtype=np.float32, buffer=nl.psum)
+    dk_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
+                        dtype=np.float32, buffer=nl.psum)
+    for i_q_seq_tile in _range(q_seq_n_tiles):
+      # Prefetch dy, Q
+      dy_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype)
+      q_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype)
+
+      load_dy_q(dy_ref_hbm_tile = dy_ref[batch_id, head_id],
+                q_ref_hbm_tile = q_ref[batch_id, head_id],
+                dy_local=dy_local, q_local=q_local, d_head_n_tiles=d_head_n_tiles,
+                d_head_tile_size=d_head_tile_size, i_q_seq_tile=i_q_seq_tile,
+                q_seq_tile_size=q_seq_tile_size, softmax_scale=softmax_scale)
+
+      logit_bias_tile = None
+      if logit_bias_ref is not None:
+        i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size)
+        logit_bias_tile = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size),
+                                     buffer=nl.sbuf, dtype=kernel_dtype)
+        logit_bias_tile[:, :] = nl.load(
+          logit_bias_ref[0, 0, i_q_seq_dslice, i_k_seq_dslice])
+
+      _flash_attn_bwd_core(
+        q_local=q_local, k_local=k_local, transposed_k_local=transposed_k_local,
+        v_local=v_local, dy_local=dy_local,
+        dk_psum=dk_psum, dv_psum=dv_psum, dq_local_reduced=dq_local_reduced,
+        softmax_exp_bias=softmax_exp_bias, dy_o_sum=dy_o_sum,
+        local_i_q_seq_tile=i_q_seq_tile, local_i_k_seq_tile=i_k_seq_tile,
+        seqlen_q=seqlen_q, seqlen_k=seqlen_k, d_head=d_head, nheads=nheads,
+        use_causal_mask=use_causal_mask,
+        kernel_dtype=kernel_dtype, mixed_dtype=mixed_dtype,
+        softmax_scale=softmax_scale,
+        seed_local=seed_local, dropout_p=dropout_p, dropout_p_local=dropout_p_local,
+        logit_bias_tile=logit_bias_tile
+      )
+
+    # Write dK, dV
+    store_dk_dv(out_dk_ref_hbm_tile=out_dk_ref[batch_id, head_id],
+                out_dv_ref_hbm_tile=out_dv_ref[batch_id, head_id],
+                local_dk=dk_psum, local_dv=dv_psum, i_k_seq_dslice=i_k_seq_dslice,
+                d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size)
+
+  # Write dQ
+  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
+    for i_d_head_tile in nl.affine_range(d_head_n_tiles):
+      i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size)
+      i_d_head_dslice = nl.ds(i_d_head_tile * d_head_tile_size, d_head_tile_size)
+      nl.store(
+        out_dq_ref[batch_id, head_id, i_d_head_dslice, i_q_seq_dslice],
+        value=dq_local_reduced[i_q_seq_tile, i_d_head_tile, :, :],
+      )
+
+  return out_dq_ref, out_dk_ref, out_dv_ref
+
+
+@nki.jit(mode='trace')
+def load_dy_q(dy_ref_hbm_tile, q_ref_hbm_tile, dy_local, q_local, d_head_n_tiles, d_head_tile_size, i_q_seq_tile,
+              q_seq_tile_size, softmax_scale):
+  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
+    i_d_head_dslice = nl.ds(i_d_head_tile * d_head_tile_size, d_head_tile_size)
+    i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size)
+
+    dy_local[i_d_head_tile, :, :] = nl.load(
+      dy_ref_hbm_tile[i_d_head_dslice, i_q_seq_dslice],
+      dtype=dy_local.dtype)
+
+    q_local[i_d_head_tile, :, :] = nl.load(
+      q_ref_hbm_tile[i_d_head_dslice, i_q_seq_dslice],
+      dtype=q_local.dtype) * softmax_scale
+
+
+@nki.jit(mode='trace')
+def store_dk_dv(out_dk_ref_hbm_tile, out_dv_ref_hbm_tile, local_dk, local_dv,
+                d_head_n_tiles, d_head_tile_size, i_k_seq_dslice):
+  for i in nl.affine_range(d_head_n_tiles):
+    i_d_head_dslice = nl.ds(i * d_head_tile_size, d_head_tile_size)
+
+    nl.store(out_dv_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice],
+             value=local_dv[i, :, :])
+
+    nl.store(out_dk_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice],
+             value=local_dk[i, :, :])
+
+
+@nki.jit(mode='trace')
+def load_kv(k_ref_hbm_tile, v_ref_hbm_tile, k_local, transposed_k_local, v_local,
+            d_head_n_tiles, d_head_tile_size, i_k_seq_tile, k_seq_tile_size,
+            k_seq_tile_size_backward):
+  k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward
+
+  for i in nl.affine_range(d_head_n_tiles):
+    i_d_head_dslice = nl.ds(i * d_head_tile_size, d_head_tile_size)
+    i_k_seq_dslice = nl.ds(i_k_seq_tile * k_seq_tile_size, k_seq_tile_size)
+    k_local[i, :, :] = nl.load(k_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice],
+                                           dtype=k_local.dtype)
+    v_local[i, :, :] = nl.load(v_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice],
+                                           dtype=v_local.dtype)
+    ##############################################################
+    # Prefetch the transposed K tiles for the backward pass as well
+    ##############################################################
+    for j in nl.affine_range(k_seq_fwd_bwd_tile_multipler):
+      i_k_dslice = nl.ds(j * k_seq_tile_size_backward, k_seq_tile_size_backward)
+      transposed_k_local[j, i, :, :] = nisa.nc_transpose(k_local[i, :, i_k_dslice])
+
+
+@nki.jit(mode='trace')
+def compute_rowsum(dy_o_sum, dy_ref_hbm_tile, o_ref_hbm_tile, d_head_n_tiles, d_head_tile_size, q_seq_n_tiles,
+                   q_seq_tile_size):
+  mixed_dtype = dy_o_sum.dtype
+  for i in nl.affine_range(q_seq_n_tiles):
+    dy_o_partial = nl.zeros((par_dim(q_seq_tile_size), d_head_n_tiles), dtype=mixed_dtype)
+    for j in nl.affine_range(d_head_n_tiles):
+      d_head_dslice = nl.ds(j * d_head_tile_size, d_head_tile_size)
+      q_seq_dslice = nl.ds(i * q_seq_tile_size, q_seq_tile_size)
+
+      dy_local = nl.load_transpose2d(dy_ref_hbm_tile[d_head_dslice, q_seq_dslice],
+                                     dtype=mixed_dtype)
+      o_local = nl.load_transpose2d(o_ref_hbm_tile[d_head_dslice, q_seq_dslice],
+                                    dtype=mixed_dtype)
+
+      dy_o = nl.multiply(dy_local, o_local, dtype=mixed_dtype)
+      dy_o_partial[:, j] = nisa.tensor_reduce(np.add, data=dy_o, axis=(1,),
+                                              dtype=mixed_dtype)
+
+    dy_o_sum[i, :, 0] = nisa.tensor_reduce(
+      np.add, data=dy_o_partial[:, :], axis=(1,), dtype=mixed_dtype)
+
+
+@nki.jit(mode='trace')
+def _flash_attn_bwd_core(
+  q_local, k_local, transposed_k_local, v_local, dy_local,
+  dk_psum, dv_psum, dq_local_reduced,
+  softmax_exp_bias, dy_o_sum,
+  local_i_q_seq_tile, local_i_k_seq_tile,
+  seqlen_q, seqlen_k, d_head, nheads,
+  use_causal_mask,
+  kernel_dtype, mixed_dtype,
+  softmax_scale,
+  seed_local, dropout_p, dropout_p_local,
+  logit_bias_tile=None):
+  """
+  The flash attention backward core function. It calculates the gradients of Q, K and V
+  for the given tiles and accumulates the results into the dk/dv PSUM buffers and dq.
+  """
+  q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128
+  d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128)
+  if seqlen_k >= 512:
+    k_seq_n_tiles, k_seq_tile_size = seqlen_k // 512, 512
+  else:
+    k_seq_n_tiles, k_seq_tile_size = seqlen_k // 128, 128
+  k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128
+  k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward
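+  # Worked example of the tiling above (an illustration, assuming
+  # seqlen_q = seqlen_k = 2048 and d_head = 128): q_seq_n_tiles = 16 tiles of
+  # 128 rows, k_seq_n_tiles = 4 tiles of 512 columns, d_head_n_tiles = 1, and
+  # the backward K tiling uses 16 tiles of 128, so
+  # k_seq_fwd_bwd_tile_multipler = 512 // 128 = 4.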
+
+  mask = local_i_q_seq_tile * q_seq_tile_size >= local_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None
+  # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F]
+  qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
+                      dtype=np.float32, buffer=nl.psum)
+  qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), buffer=nl.sbuf, dtype=kernel_dtype)
+
+  batch_id = nl.program_id(axis=0)
+  head_id = nl.program_id(axis=1)
+
+  # Loop over contraction dim of QK matmul
+  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
+    ##############################################################
+    # Step 2.1 Compute Q^T@K, with matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
+    ##############################################################
+    qk_psum[:, :] += nisa.nc_matmul(q_local[i_d_head_tile, :, :],
+                                            k_local[i_d_head_tile, :, :],
+                                            mask=mask)
+
+  ######################################
+  # Step 2.2. Apply optional causal mask
+  ######################################
+  if use_causal_mask:
+    iq, ik = nl.mgrid[0:q_seq_tile_size, 0:k_seq_tile_size]
+    causal_pred = (local_i_q_seq_tile * q_seq_tile_size + iq >= local_i_k_seq_tile * k_seq_tile_size + ik)
+    if logit_bias_tile is not None:
+      # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+      intermediate = \
+        nl.add(qk_psum[:, :], logit_bias_tile[:, :], dtype=mixed_dtype, mask=mask)
+      qk_res_buf[:, :] = nisa.affine_select(
+        pred=causal_pred, 
+        on_true_tile=intermediate, on_false_value=-9984.0, dtype=mixed_dtype,
+        mask=mask
+      )
+
+    else:
+      # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+      qk_res_buf[:, :] = nisa.affine_select(
+        pred=causal_pred,
+        on_true_tile=qk_psum[:, :], on_false_value=-9984.0, dtype=mixed_dtype,
+        mask=mask)
+  else:
+    if logit_bias_tile is not None:
+      # Simply add logit bias which copies back to sbuf at the same time
+      qk_res_buf[:, :] = \
+        nl.add(qk_psum[:, :], logit_bias_tile[:, :], dtype=mixed_dtype)
+    else:
+      # Simply send psum result back to sbuf
+      qk_res_buf[:, :] = \
+        nl.copy(qk_psum[:, :], dtype=mixed_dtype)
+
+  softmax_y = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
+  softmax_y[:, :] = nisa.activation(np.exp,
+                                    data=qk_res_buf[:, :],
+                                    bias=softmax_exp_bias[:, local_i_q_seq_tile],
+                                    scale=1.0,
+                                    mask=mask)
+  #####################################################################
+  # Dropout
+  #####################################################################
+  if dropout_p > 0.0:
+    offset = local_i_k_seq_tile + local_i_q_seq_tile * k_seq_n_tiles \
+              + head_id * k_seq_n_tiles * q_seq_n_tiles \
+              + batch_id * nheads * k_seq_n_tiles * q_seq_n_tiles
+    offset_seed = nl.add(seed_local[0, 0], offset, mask=mask)
+    nl.random_seed(seed=offset_seed, mask=mask)
+    softmax_y[:, :] = nl.dropout(softmax_y[:, :], rate=dropout_p_local[:, 0], mask=mask)
+    softmax_y[:, :] = nl.multiply(softmax_y[:, :], 1 / (1 - dropout_p), mask=mask)
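+    # Illustrative note on the seed offset above (assumed sizes): with
+    # k_seq_n_tiles = 4, q_seq_n_tiles = 16 and nheads = 8, the tile at
+    # (batch=1, head=2, q_tile=3, k_tile=1) gets offset
+    # 1 + 3*4 + 2*4*16 + 1*8*4*16 = 653, so every (batch, head, q, k) tile
+    # draws its dropout mask from a distinct seed.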
+
+  #####################################################################
+  # Step 3.1 Calculate the backward gradients dL/dV, where y=softmax@V
+  # in value projection with matmul(stationary=dy, moving=softmax)
+  #####################################################################
+  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
+    trans_dy = nisa.nc_transpose(dy_local[i_d_head_tile, :, :],
+                                  mask=mask)
+    dv_psum[i_d_head_tile, :, :] += \
+      nisa.nc_matmul(trans_dy, softmax_y[:, :], mask=mask)
+
+  #####################################################################
+  # Step 3.2 Calculate the backward gradients dL/dsoftmax, where y=softmax@V
+  # in value projection with matmul(stationary=dy, moving=v)
+  #####################################################################
+  softmax_dy_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
+                              dtype=np.float32, buffer=nl.psum)
+  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
+    softmax_dy_psum[:, :] += \
+      nisa.nc_matmul(dy_local[i_d_head_tile, :, :],
+                      v_local[i_d_head_tile, :, :],
+                      mask=mask)
+
+  softmax_dy = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
+  softmax_dy[:, :] = nl.copy(softmax_dy_psum[:, :], dtype=kernel_dtype,
+                                      mask=mask)
+
+  #####################################################################
+  # Step 4 Calculate the softmax backward gradients dL/dx, where y = softmax(x):
+  # dL/dx = y * (dL/dy - rowsum(dO ◦ O))
+  #####################################################################
+  softmax_dx_local = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
+  softmax_dx_local[:, :] = \
+    nisa.scalar_tensor_tensor(data=softmax_dy[:, :],
+                              op0=np.subtract,
+                              operand0=dy_o_sum[local_i_q_seq_tile, :, 0],
+                              op1=np.multiply,
+                              operand1=softmax_y[:, :],
+                              mask=mask)
+
+  #####################################################################
+  # Step 5.1 Calculate dK, with matmul(stationary=Q, moving=softmax_dx)
+  #####################################################################
+  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
+    trans_q_local = nisa.nc_transpose(q_local[i_d_head_tile, :, :],
+                                      mask=mask)
+    dk_psum[i_d_head_tile, :, :] += \
+      nisa.nc_matmul(trans_q_local,
+                      softmax_dx_local[:, :],
+                      mask=mask)
+
+  #####################################################################
+  # Step 5.2 Calculate dQ
+  #####################################################################
+  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
+    dq_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size),
+                        dtype=np.float32, buffer=nl.psum)
+    for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler):
+      i_k_seq_dslice = nl.ds(i_k_seq_tile_backward * k_seq_tile_size_backward,
+                             k_seq_tile_size_backward)
+      transposed_softmax_dx_local = \
+        nisa.nc_transpose(softmax_dx_local[:, i_k_seq_dslice],
+                          mask=mask)
+      dq_psum[:, :] += nisa.nc_matmul(
+          transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, :, :],
+          transposed_softmax_dx_local,
+          mask=mask)
+    dq_local = nl.multiply(dq_psum[:, :], softmax_scale, dtype=kernel_dtype, mask=mask)
+    dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, :, :] = nl.loop_reduce(
+      dq_local, op=np.add, loop_indices=(local_i_k_seq_tile,),
+      dtype=mixed_dtype, mask=mask)
+
+
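+# Illustrative reference (an assumption, not part of the kernel API): a plain
+# NumPy sketch of the math implemented by Steps 1-5 above, written with
+# row-major (seq, d_head) tensors. The helper name and argument layout are
+# hypothetical and exist purely to document the computation.
+def _numpy_flash_attn_bwd_reference(q, k, v, o, dy, lse, softmax_scale):
+  # q, o, dy: (seq_q, d_head); k, v: (seq_k, d_head); lse: (seq_q,) row logsumexp
+  s = softmax_scale * (q @ k.T)                     # scaled logits
+  p = np.exp(s - lse[:, None])                      # softmax probabilities
+  delta = np.sum(dy * o, axis=-1, keepdims=True)    # Step 1: rowsum(dO * O)
+  dv = p.T @ dy                                     # Step 3.1
+  dp = dy @ v.T                                     # Step 3.2
+  ds = p * (dp - delta)                             # Step 4
+  dk = softmax_scale * (ds.T @ q)                   # Step 5.1
+  dq = softmax_scale * (ds @ k)                     # Step 5.2
+  return dq, dk, dv
+
+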
+@nki.jit
+def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
+                                           mixed_precision=True):
+  """
+  Fused self attention kernel for small head size Stable Diffusion workload.
+
+  Computes softmax(QK^T)V. A causal mask can optionally be applied for decoder
+  models. Does not include QKV projection, output projection, dropout,
+  residual connection, etc.
+
+  This kernel is designed to be used for Stable Diffusion models where
+  n_heads is smaller than or equal to 128. An assertion is raised if `n_heads`
+  does not satisfy the requirement.
+
+  IO tensor layouts:
+   - q_ptr: shape   (bs, n_heads, seq_q)
+   - k_ptr: shape   (bs, seq_k, n_heads)
+   - v_ptr: shape   (bs, seq_v, n_heads)
+   - out_ptr: shape (bs, seq_q, n_heads)
+   - We use seq_q and seq_k just for clarity; this kernel requires seq_q == seq_k
+
+  IO tensor dtypes:
+   - This kernel assumes all IO tensors have the same dtype
+   - If mixed_precision is True, then all Tensor Engine operations will be performed in
+     bfloat16 and accumulation will be performed in float32. Otherwise, the intermediates
+     will be of the same type as the inputs.
+  """
+  # Use q_ref dtype as the intermediate tensor dtype
+  # Assume all IO tensors have the same dtype
+  kernel_dtype = q_ref.dtype
+  pe_in_dt = nl.bfloat16 if mixed_precision else np.float32
+  assert q_ref.dtype == k_ref.dtype == v_ref.dtype
+
+  # Shape checking
+  bs, d_head, seqlen = q_ref.shape
+  assert d_head <= 128, "Cannot use this kernel for d_head > 128"
+  assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!'
+  assert tuple(k_ref.shape) == (bs, seqlen, d_head), 'Input shape mismatch!'
+  assert tuple(v_ref.shape) == (bs, seqlen,  d_head), \
+    f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
+
+  out_ref = nl.ndarray((bs, seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm)
+
+  # Softmax scaling factor, multiplied onto Q
+  softmax_scale = 0.125
+
+  # Different batch samples/attention heads have independent attention
+  batch_id = nl.program_id(axis=0)
+  # batch_id = 0
+
+  # TODO: make q_seq_tile_size user input
+  # The matmuls currently use a fixed tile size of (128, 128). This may not achieve the best
+  # performance for dense attention. However, since this kernel is in preparation
+  # for block-sparse attention, this tile size is acceptable because the block
+  # size of block-sparse attention cannot be too large.
+  q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128
+  k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128
+  # No tiling on the d_head dimension since d_head (<= 128) fits in SBUF
+  d_head_tile_size = d_head
+  v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128
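+  # Worked example (an illustration, assuming seqlen = 1024 and d_head = 64):
+  # the loops below run over 8 query tiles and 8 key/value tiles of 128
+  # positions each, while d_head stays untiled because it fits in one SBUF tile.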
+
+  ###################################
+  # Step 1. transpose(tensor_v)
+  ###################################
+  # Buffer for the V matrix in transposed (seq, d_head) layout
+  # Prefetch it and keep it in SBUF across the different softmax tiles
+  trans_v = nl.ndarray((par_dim(v_seq_tile_size), v_seq_n_tiles, d_head), dtype=pe_in_dt)
+
+  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
+    ip_v = nl.arange(v_seq_tile_size)[:, None]
+    if_v = nl.arange(d_head_tile_size)[None, :]
+    trans_v[ip_v, i_k_seq_tile, if_v] = nl.load(
+      v_ref[batch_id, i_k_seq_tile * k_seq_tile_size + ip_v, if_v],
+      dtype=pe_in_dt)
+
+  q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt)
+  ip_q = nl.arange(d_head_tile_size)[:, None]
+  if_q = nl.arange(q_seq_tile_size)[None, :]
+  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
+    q_local[i_q_seq_tile, ip_q, if_q] = nl.load(
+      q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q],
+      dtype=pe_in_dt) * softmax_scale
+
+  k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt)
+  ip_k = nl.arange(d_head_tile_size)[:, None]
+  if_k = nl.arange(k_seq_tile_size)[None, :]
+  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
+    k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d(
+      k_ref[batch_id,
+            i_k_seq_tile * k_seq_tile_size + nl.arange(k_seq_tile_size)[:, None],
+            nl.arange(d_head_tile_size)[None, :]],
+      dtype=pe_in_dt)
+
+  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):  # indent = 2
+    # An SBUF buffer for an independent softmax tile
+    qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype)
+
+    neg_max_res = nl.ndarray((par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype)
+    ip_max = nl.arange(q_seq_tile_size)[:, None]
+    if_max = nl.arange(k_seq_n_tiles)[None, :]
+
+    # Loop over the RHS free dimension of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
+    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):  # indent = 4
+
+      # Since the K^T tile is the RHS, the q_seq_len dimension will be P in the result
+      # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F]
+      qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
+                         dtype=np.float32, buffer=nl.psum)
+
+      # Tensor indices for accessing the qk result within one k_seq_tile
+      ip_qk = nl.arange(q_seq_tile_size)[:, None]
+      if_qk = nl.arange(k_seq_tile_size)[None, :]
+
+      ##############################################################
+      # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
+      ##############################################################
+      qk_psum[ip_qk, if_qk] += nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k],
+                                              stationary=q_local[i_q_seq_tile, ip_q, if_q])
+
+      ###################################
+      # Step 3. Apply optional causal mask
+      ###################################
+      if use_causal_mask:
+        # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
+        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select(
+          pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk),
+          on_true_tile=qk_psum[ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype)
+      else:
+        # Simply send psum result back to sbuf
+        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nl.copy(qk_psum[ip_qk, if_qk],
+                                                                              dtype=kernel_dtype)
+
+      ###################################
+      # Step 4. Softmax
+      ###################################
+      # TODO: use TensorScalarCacheReduce to avoid an extra copy
+      # We break this reduction into tiles so that it can overlap with the previous matmul
+      neg_max_res[ip_max, i_k_seq_tile] = nisa.tensor_reduce(
+        np.max, data=qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk],
+        axis=(1,), dtype=kernel_dtype, negate=True)
+
+    neg_max_res_final = nisa.tensor_reduce(
+      np.min, data=neg_max_res[ip_max, if_max],
+      axis=(1,), dtype=kernel_dtype, negate=False)
+
+    ip_softmax = nl.arange(q_seq_tile_size)[:, None]
+    if_softmax = nl.arange(seqlen)[None, :]
+    ip_sum_res = nl.arange(q_seq_tile_size)[:, None]
+    if_sum_res = nl.arange(d_head_tile_size)[None, :]
+
+    softmax_res = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=pe_in_dt)
+    sum_divisor = nl.ndarray((par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype)
+
+    # Simply use one large tile of size seq_len since this is a "blocking" instruction
+    # Assume the compiler will merge exp and reduce_add into a single instruction on ACT
+    exp_res = nisa.activation(np.exp,
+                              data=qk_res_buf[ip_softmax, if_softmax],
+                              bias=neg_max_res_final, scale=1.0)
+
+    sum_res = nisa.tensor_reduce(np.add, data=exp_res, axis=(1,),
+                          dtype=kernel_dtype)
+    softmax_res[ip_softmax, if_softmax] = nl.copy(exp_res, dtype=pe_in_dt)
+
+    sum_reciprocal_broadcast = (1.0 / sum_res).broadcast_to((q_seq_tile_size, d_head_tile_size))
+    sum_divisor[ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast, dtype=kernel_dtype)
+
+    # Buffer (in SBUF) for the transposed softmax results
+    trans_softmax_res = nl.ndarray(
+      (par_dim(k_seq_tile_size), k_seq_n_tiles, q_seq_tile_size),
+      dtype=pe_in_dt)
+
+    # Result PSUM buffer has the hidden (d_head) dimension as the partition dim
+    attn_res_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size),
+                             dtype=np.float32, buffer=nl.psum)
+
+    ip_scores_t = nl.arange(k_seq_tile_size)[:, None]
+    if_scores_t = nl.arange(q_seq_tile_size)[None, :]
+    # Loop over matmul_1 contraction
+    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
+      ###################################
+      # Step 5. transpose(softmax_res)
+      ###################################
+      ip_scores = nl.arange(q_seq_tile_size)[:, None]
+      if_scores = nl.arange(k_seq_tile_size)[None, :]
+
+      trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t] = nisa.nc_transpose(
+        softmax_res[ip_scores, i_k_seq_tile * k_seq_tile_size + if_scores])
+
+    ip_out = nl.arange(d_head_tile_size)[:, None]
+    if_out = nl.arange(q_seq_tile_size)[None, :]
+    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
+      ######################################################################
+      # Step 6. matmul_1(stationary=trans_v, moving=trans_softmax_res, contract=seqlen_v=seqlen_k)
+      ######################################################################
+      ip_v_t = nl.arange(k_seq_tile_size)[:, None]
+      if_v_t = nl.arange(d_head_tile_size)[None, :]
+      attn_res_psum[ip_out, if_out] += \
+        nisa.nc_matmul(moving=trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t],
+                       stationary=trans_v[ip_v_t, i_k_seq_tile, if_v_t])
+
+    attn_res_sbuf = nl.copy(attn_res_psum[ip_out, if_out], dtype=kernel_dtype)
+
+    attn_res_div = attn_res_sbuf * nisa.nc_transpose(sum_divisor[ip_sum_res, if_sum_res])
+
+    nl.store(
+      out_ref[batch_id, i_q_seq_tile * q_seq_tile_size + if_out, ip_out],
+      value=attn_res_div)
+
+  return out_ref
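+
+
+# Hypothetical usage sketch (an assumption for illustration, not part of the
+# reference API): builds NumPy inputs with the documented layouts, runs the
+# kernel as the other examples in this change do, and checks the result
+# against a plain softmax(QK^T)V. Sizes are arbitrary; d_head = 64 matches the
+# hard-coded softmax scale of 0.125.
+def _demo_fused_self_attn_numpy_check(bs=1, d_head=64, seqlen=1024):
+  q = np.random.rand(bs, d_head, seqlen).astype(np.float32)
+  k = np.random.rand(bs, seqlen, d_head).astype(np.float32)
+  v = np.random.rand(bs, seqlen, d_head).astype(np.float32)
+
+  out_nki = fused_self_attn_for_SD_small_head_size(q, k, v)
+
+  scores = np.einsum('bds,bkd->bsk', q, k) * 0.125
+  probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
+  probs /= probs.sum(axis=-1, keepdims=True)
+  out_ref_np = np.einsum('bsk,bkd->bsd', probs, v)
+  return np.allclose(out_nki, out_ref_np, atol=1e-2)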
diff --git a/src/nki_samples/reference/tutorial.py b/src/nki_samples/reference/tutorial.py
new file mode 100644
index 0000000..b32492b
--- /dev/null
+++ b/src/nki_samples/reference/tutorial.py
@@ -0,0 +1,31 @@
+"""
+Copyright (c) 2023, Amazon.com. All Rights Reserved
+
+kernels - Built-in high-performance NKI kernels used in the tutorials
+
+"""
+
+from neuronxcc import nki
+import neuronxcc.nki.language as nl
+
+
+@nki.jit
+def add_kernel_nx8x128x512(a_ptr, b_ptr, n_elements):
+  c_ptr = nl.ndarray(a_ptr.shape, dtype=a_ptr.dtype, buffer=nl.shared_hbm)
+
+  ix = nl.arange(128)[:, None]
+  iy = nl.arange(512)[None, :]
+
+  tile_size = 128 * 512
+  block_size = 8 * tile_size
+
+  j = nl.program_id(axis=0)
+
+  for i in nl.affine_range(8):
+    offset = j * block_size + i * tile_size + 512 * ix + iy
+    a = nl.load(a_ptr[j, i, ix, iy], mask=offset < n_elements)
+    b = nl.load(b_ptr[j, i, ix, iy], mask=offset < n_elements)
+    c = nl.add(a, b, mask=offset < n_elements)
+    nl.store(c_ptr[j, i, ix, iy], value=c, mask=offset < n_elements)
+
+  return c_ptr
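+
+# Hypothetical usage note (an assumption, for illustration only): `a_ptr` and
+# `b_ptr` are expected to be laid out as (num_blocks, 8, 128, 512) so that
+# each SPMD program instance `j` processes one block of 8 tiles of 128x512
+# elements, with `n_elements` masking any ragged tail in the last block. A
+# launch could then look like
+#   c = add_kernel_nx8x128x512[num_blocks](a, b, n_elements)
+# where the bracketed SPMD launch-grid syntax is assumed from the NKI guide.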
diff --git a/src/reference/vision.py b/src/nki_samples/reference/vision.py
similarity index 93%
rename from src/reference/vision.py
rename to src/nki_samples/reference/vision.py
index bc54941..4899d27 100644
--- a/src/reference/vision.py
+++ b/src/nki_samples/reference/vision.py
@@ -8,10 +8,13 @@
 
 import neuronxcc.nki.language as nl
 import neuronxcc.nki.isa as nisa
+from neuronxcc import nki
 from neuronxcc.nki.language import par_dim
 import neuronxcc.nki.typing as nt
 
-def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):
+
+@nki.jit
+def select_and_scatter_kernel(operand_tensor, source_tensor):
   """
   Implementation of a select-and-scatter kernel.
 
@@ -51,7 +54,10 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):
   assert C == 64 and N % 2 == 0
 
   kernel_dtype = operand_tensor.dtype
-  assert operand_tensor.dtype == source_tensor.dtype == out_tensor.dtype
+  assert operand_tensor.dtype == source_tensor.dtype
+
+  out_tensor = nl.ndarray((N, C, H, W), dtype=operand_tensor.dtype,
+                          buffer=nl.shared_hbm)
 
   p = 128  # num of partitions to use
   for ib in nl.affine_range(N // 2):
@@ -156,8 +162,11 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):
       nl.store(out_tensor[2 * ib + ib_1, 0:64, 0:H, 0:W],
                value=out_local[(ib_1 * 64):((ib_1 + 1) * 64), 0:H, 0:W])
 
+  return out_tensor
 
-def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):
+
+@nki.jit
+def resize_nearest_fixed_dma_kernel(data_tensor, out_shape):
   """
   Resize the input image to the given size using the nearest interpolation mode. This kernel is designed to be used when the scaling factor is not an integer. 
 
@@ -174,7 +183,9 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):
   
   """
   in_b, in_h, in_w, in_c = data_tensor.shape
-  out_b, out_h, out_w, out_c = out_tensor.shape
+  out_b, out_h, out_w, out_c = out_shape
+  out_tensor = nl.ndarray(out_shape, dtype=data_tensor.dtype,
+                          buffer=nl.shared_hbm)
 
   assert in_b == out_b, "Input batch and output batch must be identical"
   assert in_c == out_c, "Input channel and output channel must be identical"
@@ -198,3 +209,5 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):
     local_data = nl.load(target_addr)
     dst_addr_0 = out_tile[b_map, i, c_map]
     nl.store(dst_addr_0, value=local_data)
+
+  return out_tensor
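+
+# Hypothetical usage (an illustration only): with an NHWC image batch of shape
+# (1, 30, 30, 128), `resize_nearest_fixed_dma_kernel(img, (1, 59, 59, 128))`
+# returns the nearest-neighbour-resized tensor; the batch and channel
+# dimensions of `out_shape` must match those of the input.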
diff --git a/src/tutorials/average_pool2d/average_pool2d_jax.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py
similarity index 68%
rename from src/tutorials/average_pool2d/average_pool2d_jax.py
rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py
index e3b428d..139c42d 100644
--- a/src/tutorials/average_pool2d/average_pool2d_jax.py
+++ b/src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py
@@ -4,29 +4,22 @@
 JAX implementation for average pool 2D NKI tutorial.
 
 """
-from functools import partial
-from jax_neuronx import nki_call
-import jax
+# NKI_EXAMPLE_40_BEGIN
 import jax.numpy as jnp
-
-from average_pool2d_nki_kernels import tensor_avgpool_kernel_
-
-
-def tensor_avgpool_kernel(in_array, pool_size):
-  return nki_call(
-    partial(tensor_avgpool_kernel_, pool_size=pool_size),
-    in_array,
-    out_shape=jax.ShapeDtypeStruct((C, HOUT, WOUT), dtype=in_array.dtype),
-  )
+# NKI_EXAMPLE_40_END
+from average_pool2d_nki_kernels import tensor_avgpool_kernel
 
 
+# NKI_EXAMPLE_40_BEGIN
 # Reference JAX implementation
 def jax_average_pool_2D(in_tensor, pool_size):
   c, h_in, w_in = in_tensor.shape
   reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
   return jnp.nanmean(reshaped, axis=(2, 4))
+  # NKI_EXAMPLE_40_END
 
 
+# NKI_EXAMPLE_41_BEGIN
 if __name__ == "__main__":
   POOL_SIZE = 2
   C, HIN, WIN = 2, 6, 6
@@ -34,7 +27,9 @@ def jax_average_pool_2D(in_tensor, pool_size):
 
   in_array = jnp.arange(C * HIN * WIN, dtype=jnp.float32).reshape(C, HIN, WIN)
 
+  # NKI_EXAMPLE_39_BEGIN
   out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE)
+  # NKI_EXAMPLE_39_END
   out_jax = jax_average_pool_2D(in_array, pool_size=POOL_SIZE)
 
   print(in_array, out_nki, out_jax)
@@ -42,4 +37,5 @@ def jax_average_pool_2D(in_tensor, pool_size):
   if jnp.allclose(out_nki, out_jax):
     print("NKI and JAX match")
   else:
-    print("NKI and JAX differ")
\ No newline at end of file
+    print("NKI and JAX differ")
+    # NKI_EXAMPLE_41_END
diff --git a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py
similarity index 59%
rename from src/tutorials/average_pool2d/average_pool2d_nki_kernels.py
rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py
index c81a4a5..68d3a31 100644
--- a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py
+++ b/src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py
@@ -5,48 +5,40 @@
 
 """
 import numpy as np
+# NKI_EXAMPLE_37_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
+from neuronxcc.nki.typing import tensor
 
-
-def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
+@nki.jit
+def tensor_avgpool_kernel(in_tensor, pool_size):
   """NKI kernel to compute a 2D avg-pool operation
 
   Args:
       in_tensor: an input tensor, of shape C x H x W
       pool_size: an integer representing a (square) pool-window size
+
+  Return:
       out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
   """
 
   # Get input/output dimensions
   sz_cin, sz_hin, sz_win = in_tensor.shape
-  sz_cout, sz_hout, sz_wout = out_tensor.shape
-  assert sz_cin == sz_cout
+  sz_hout = sz_hin // pool_size
+  sz_wout = sz_win // pool_size
+  # Create output tensor shared between all SPMD instances as result tensor
+  out_tensor = nl.ndarray((sz_cin, sz_hout, sz_wout), dtype=in_tensor.dtype,
+                          buffer=nl.shared_hbm)
 
   # Set relevant sizes
   sz_p = sz_cin
   sz_pool = pool_size
 
-  # Generate tensor h/w index patterns
-  # 3D indexing according to [C, H, W]
-  i_p = nl.arange(sz_p)[:, None, None] # 3D for
-  i_win = nl.arange(sz_win)[None, None, :]
-  i_hin = nl.arange(sz_hin)[None, :, None]
-
-  i_wout = nl.arange(sz_wout)[None, None, :]
-  i_hout = nl.arange(sz_hout)[None, :, None]
-
   # Generate pool index patterns (requires two extra dimensions, for the pool window)
-  i_0 = nl.arange(sz_p)[:, None, None, None, None] #
-  i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer
-  i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner
-  i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer
-  i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner
+  i0, i1, i2, i3, i4 = nl.mgrid[0:sz_p, 0:sz_hin//sz_pool, 0:sz_pool, 0:sz_win//sz_pool, 0:sz_pool]
 
   # Load input data from external memory to on-chip memory
-  # Declare ndarray to force a 3D tensor (temporary requirement)
-  in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
-  in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])
+  in_tile: tensor[sz_p, sz_hin, sz_win] = nl.load(in_tensor)
 
   # Perform the pooling operation:
   # We use numpy's advanced indexing, in order to extend in_tile to 5D, and then reduce-average two dimension.
@@ -54,10 +46,15 @@ def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
   # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
   # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
   # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
-  out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size)
+  out_tile : tensor[sz_p, sz_hout, sz_wout] = nl.sum(in_tile[i0, sz_pool*i1+i2, sz_pool*i3+i4],
+                                                     axis=[2,4]) / (pool_size*pool_size)
+
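+  # Worked example of the index pattern above (an illustration, assuming
+  # pool_size = 2 and H = W = 4): i1 and i3 range over the 2 output rows/cols,
+  # while i2 and i4 range over the 2 positions inside each window, so the row
+  # index 2*i1 + i2 visits rows (0, 1) for output row 0 and rows (2, 3) for
+  # output row 1; reducing over axes 2 and 4 then averages each 2x2 window.
+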
+  # Store the results back to hbm
+  nl.store(out_tensor, value=out_tile)
 
-  # Store the results back to external memory
-  nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)
+  # Transfer the ownership of `out_tensor` to the caller
+  return out_tensor
+  # NKI_EXAMPLE_37_END
 
 
 # Reference NumPy implementation
@@ -74,10 +71,8 @@ def np_average_pool_2D(in_tensor, pool_size):
   HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE
 
   in_tensor = np.arange(C * HIN * WIN, dtype=np.float16).reshape(C, HIN, WIN)
-  out_nki = np.zeros((C, HOUT, WOUT), dtype=np.float16)
 
-  tensor_avgpool_kernel_baremetal = nki.baremetal(tensor_avgpool_kernel_)
-  tensor_avgpool_kernel_baremetal(in_tensor, out_nki, POOL_SIZE)
+  out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE)
 
   out_np = np_average_pool_2D(in_tensor, POOL_SIZE)
 
diff --git a/src/tutorials/average_pool2d/average_pool2d_torch.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py
similarity index 78%
rename from src/tutorials/average_pool2d/average_pool2d_torch.py
rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py
index 3409a31..c5fb4ea 100644
--- a/src/tutorials/average_pool2d/average_pool2d_torch.py
+++ b/src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py
@@ -4,13 +4,14 @@
 PyTorch implementation for average pool 2D NKI tutorial.
 
 """
+# NKI_EXAMPLE_38_BEGIN
 import torch
-from torch_neuronx import nki_jit
 from torch_xla.core import xla_model as xm
-
-from average_pool2d_nki_kernels import tensor_avgpool_kernel_
+# NKI_EXAMPLE_38_END
+from average_pool2d_nki_kernels import tensor_avgpool_kernel
 
 
+# NKI_EXAMPLE_38_BEGIN
 if __name__ == "__main__":
   device = xm.xla_device()
 
@@ -22,8 +23,7 @@
   in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device)
   out_nki = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device)
 
-  tensor_avgpool_kernel_torch = nki_jit(tensor_avgpool_kernel_)
-  tensor_avgpool_kernel_torch(in_tensor, out_nki, POOL_SIZE)
+  out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE)
 
   out_torch = torch.nn.functional.avg_pool2d(in_tensor, POOL_SIZE, POOL_SIZE)
 
@@ -33,3 +33,4 @@
     print("NKI and Torch match")
   else:
     print("NKI and Torch differ")
+    # NKI_EXAMPLE_38_END
diff --git a/src/tutorials/fused_mamba/mamba_nki_kernels.py b/src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py
similarity index 94%
rename from src/tutorials/fused_mamba/mamba_nki_kernels.py
rename to src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py
index 9f8af60..4ff6642 100644
--- a/src/tutorials/fused_mamba/mamba_nki_kernels.py
+++ b/src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py
@@ -4,16 +4,19 @@
 Mamba-v1 NKI kernel implementation.
 
 """
+# NKI_EXAMPLE_25_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 import neuronxcc.nki.isa as nisa
 import numpy as np
+# NKI_EXAMPLE_25_END
 import os
 import argparse
 import itertools
 
-
-def mamba_v1(delta, u, A, B, C, output):
+# NKI_EXAMPLE_25_BEGIN
+@nki.jit
+def mamba_v1(delta, u, A, B, C):
     """Computes the SSM operation in the Mamba model.
 
     :param delta: (batch_size, channels, seq_len)
@@ -24,6 +27,9 @@ def mamba_v1(delta, u, A, B, C, output):
     :return: (batch_size, channels, seq_len)
     """
     batch_size, channels, seq_len = delta.shape
+    output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype,
+                        buffer=nl.shared_hbm)
+
     _, state_size = A.shape
 
     # We can relax this using mask paramters in all the NKI API calls
@@ -84,8 +90,12 @@ def mamba_v1(delta, u, A, B, C, output):
             nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                     scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len])
 
+    return output
+# NKI_EXAMPLE_25_END
 
-def mamba_v2(delta, u, A, B, C, output):
+# NKI_EXAMPLE_26_BEGIN
+@nki.jit
+def mamba_v2(delta, u, A, B, C):
     """Computes the SSM operation in the Mamba model.
 
     :param delta: (batch_size, channels, seq_len)
@@ -96,6 +106,8 @@ def mamba_v2(delta, u, A, B, C, output):
     :return: (batch_size, channels, seq_len)
     """
     batch_size, channels, seq_len = delta.shape
+    output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype,
+                        buffer=nl.shared_hbm)
     _, state_size = A.shape
 
     assert channels % 128 == 0
@@ -153,8 +165,12 @@ def mamba_v2(delta, u, A, B, C, output):
             nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                     scanC_accum[0:channel_psize, 0:seq_len])
 
+    return output
+# NKI_EXAMPLE_26_END
+
 
-def mamba_v3(delta, u, A, B, C, output):
+@nki.jit
+def mamba_v3(delta, u, A, B, C):
     """Computes the SSM operation in the Mamba model.
 
     :param delta: (batch_size, channels, seq_len)
@@ -165,6 +181,8 @@ def mamba_v3(delta, u, A, B, C, output):
     :return: (batch_size, channels, seq_len)
     """
     batch_size, channels, seq_len = delta.shape
+    output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype,
+                        buffer=nl.shared_hbm)
     _, state_size = A.shape
 
     # Map channels to the partition dimension
@@ -239,6 +257,7 @@ def mamba_v3(delta, u, A, B, C, output):
             # Store scanC_accum for a single batch to output
             nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                     scanC_accum[0:channel_psize, 0:seq_len])
+    return output
 
 
 def parse_args():
@@ -310,9 +329,7 @@ def parse_args():
         if args.mode == "accuracy":
             # v1: reference kernel
             print(f">>>> Running v1 (reference).")
-            nki_out_v1 = np.empty((batch, channels, seq_len), dtype=dtype)
-            nki.baremetal(mamba_v1)\
-                         (delta, u, A, B, C, nki_out_v1)
+            nki_out_v1 = mamba_v1(delta, u, A, B, C)
 
             for version in args.version:
                 if version == "v1":
@@ -321,9 +338,7 @@ def parse_args():
 
                 print(f">>>> Running version {version}.")
                 func = func_dict[version]
-                nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype)
-                nki.baremetal(func)\
-                             (delta, u, A, B, C, nki_out_test)
+                nki_out_test = func(delta, u, A, B, C)
                 print(f">>>> mamba {version} matches?", np.all(nki_out_test == nki_out_v1))
                 assert np.all(nki_out_test == nki_out_v1)
 
@@ -333,11 +348,10 @@ def parse_args():
             for version in args.version:
                 print(f">>>> Running version {version}.")
                 func = func_dict[version]
-                nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype)
                 nki.benchmark(func,
                               save_neff_name='file.neff',
                               save_trace_name='profile.ntff')\
-                             (delta, u, A, B, C, nki_out_test)
+                             (delta, u, A, B, C)
                 # TODO: rename neff/ntff (bug in nki.benchmark with neff name)
                 os.rename("file.neff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.neff")
                 os.rename("profile.ntff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.ntff")
diff --git a/src/tutorials/fused_mamba/mamba_torch.py b/src/nki_samples/tutorials/fused_mamba/mamba_torch.py
similarity index 95%
rename from src/tutorials/fused_mamba/mamba_torch.py
rename to src/nki_samples/tutorials/fused_mamba/mamba_torch.py
index a2e593f..cd94a0b 100644
--- a/src/tutorials/fused_mamba/mamba_torch.py
+++ b/src/nki_samples/tutorials/fused_mamba/mamba_torch.py
@@ -5,6 +5,7 @@
 
 """
 
+# NKI_EXAMPLE_24_BEGIN
 import torch
 import torch_neuronx
 import torch_xla.core.xla_model as xm
@@ -99,16 +100,14 @@ def parse_args():
     torch_out = mamba_layer(delta, A, B, u, C)
     xm.mark_step()
     print(torch_out)
+    # NKI_EXAMPLE_24_END
 
     if args.mode == "accuracy":
         # Call NKI mamba_v1 kernel to check accuracy
         from mamba_nki_kernels import mamba_v1
-        from torch_neuronx import nki_jit
-
-        nki_out = torch.empty((batch, channels, seq_len), dtype=dtype, device=device)
 
         xm.mark_step()
-        nki_jit(mamba_v1)(delta, u, A, B, C, nki_out)
+        nki_out = mamba_v1(delta, u, A, B, C)
         xm.mark_step()
 
         allclose = torch.allclose(torch_out, nki_out, atol=1e-2, rtol=1e-2)
diff --git a/src/tutorials/layernorm/layernorm_nki_kernel.py b/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py
similarity index 64%
rename from src/tutorials/layernorm/layernorm_nki_kernel.py
rename to src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py
index 503ce7d..c0c235c 100644
--- a/src/tutorials/layernorm/layernorm_nki_kernel.py
+++ b/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py
@@ -4,21 +4,27 @@
 LayerNorm NKI kernel implementation.
 
 """
+# NKI_EXAMPLE_45_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 import neuronxcc.nki.isa as nisa
 import numpy as np
 import math
+# NKI_EXAMPLE_45_END
 import os
 import argparse
 
 
-def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor):
+# NKI_EXAMPLE_45_BEGIN
+@nki.jit
+def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector):
   """Computes LayerNorm.
     Used nki.language APIs only.
   """
+  output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype,
+                             buffer=nl.shared_hbm)
+
   # Ensure that the shapes of tensors match
-  assert input_tensor.shape == output_tensor.shape
   assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0]
 
   # Generate tile indices for loading/storing data
@@ -58,12 +64,20 @@ def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, ou
     nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb,
              mask=(i * nl.tile_size.pmax + i_p_io < num_rows))
 
-def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor):
+  return output_tensor
+  # NKI_EXAMPLE_45_END
+
+
+# NKI_EXAMPLE_46_BEGIN
+@nki.jit
+def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector):
   """Computes LayerNorm.
     Used nki.isa APIs to calculate mean/variance and perform shift/scale.
   """
+  output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype,
+                             buffer=nl.shared_hbm)
+
   # Ensure that the shapes of tensors match
-  assert input_tensor.shape == output_tensor.shape
   assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0]
 
   # Generate tile indices for loading/storing data
@@ -122,69 +136,66 @@ def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, ou
     nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb,
              mask=(i * nl.tile_size.pmax + i_p_io < num_rows))
 
+  return output_tensor
+  # NKI_EXAMPLE_46_END
+
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-    """Run LayerNorm pytorch implementation.
-    """)
-    parser.add_argument("--nrows",
-                        default=4*1024,
-                        type=int,
-                        help="""The number of input rows""")
-    parser.add_argument("--ncols",
-                        default=8*1024,
-                        type=int,
-                        help="""The number of input columns""")
-    parser.add_argument("--mode",
-                        choices=["accuracy", "perf"],
-                        default="accuracy",
-                        help="""Do accuracy test or perf test.
-                                Accuracy test compares LayerNorm kernel against PyTorch implementation.
-                                Perf test will generate a NEFF for the PyTorch implementation in local directory
-                                for a manual run of neuron-profile.
-                             """)
-    args = parser.parse_args()
-    return args
+  parser = argparse.ArgumentParser(
+  """Run LayerNorm pytorch implementation.
+  """)
+  parser.add_argument("--nrows",
+                      default=4*1024,
+                      type=int,
+                      help="""The number of input rows""")
+  parser.add_argument("--ncols",
+                      default=8*1024,
+                      type=int,
+                      help="""The number of input columns""")
+  parser.add_argument("--mode",
+                      choices=["accuracy", "perf"],
+                      default="accuracy",
+                      help="""Do accuracy test or perf test.
+                              Accuracy test compares LayerNorm kernel against PyTorch implementation.
+                              Perf test will generate a NEFF for the PyTorch implementation in local directory
+                              for a manual run of neuron-profile.
+                           """)
+  args = parser.parse_args()
+  return args
 
 if __name__ == "__main__":
-    args = parse_args()
-    func_dict = {"v1": nki_layernorm_kernel_v1,
-                 "v2": nki_layernorm_kernel_v2,
-                 }
-
-    # Generate toy example
-    num_rows = args.nrows
-    num_cols = args.ncols
-    input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32)
-    gamma_vector = np.random.rand(num_cols).astype(np.float32)
-    beta_vector = np.random.rand(num_cols).astype(np.float32)
-    epsilon = 1e-5
-            
-    if args.mode == "accuracy":
-        # version 1
-        print(f">>>> Running version 1")
-        nki_out_v1 = np.empty((num_rows, num_cols), dtype=np.float32)
-        nki.baremetal(nki_layernorm_kernel_v1)\
-                    (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v1)
-        # version 2
-        print(f">>>> Running version 2")
-        nki_out_v2 = np.empty((num_rows, num_cols), dtype=np.float32)
-        nki.baremetal(nki_layernorm_kernel_v2)\
-                    (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v2)
-        # compare
-        np_all = np.all(nki_out_v1 == nki_out_v1)
-        print(f">>>> LayerNorm V1 and V2 matches?", np_all)
-        assert np_all
-                
-    else:
-      # perf mode
-      for version in ["v1", "v2"]:
-          print(f">>>> Running version {version}.")
-          func = func_dict[version]
-          nki_out_test = np.empty((num_rows, num_cols), dtype=np.float32)
-          nki.benchmark(func,
-                        save_neff_name='file.neff',
-                        save_trace_name='profile.ntff')\
-                        (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_test)
-          os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff")
-          os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff")
+  args = parse_args()
+  func_dict = {"v1": nki_layernorm_kernel_v1,
+               "v2": nki_layernorm_kernel_v2,
+               }
+
+  # Generate toy example
+  num_rows = args.nrows
+  num_cols = args.ncols
+  input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32)
+  gamma_vector = np.random.rand(num_cols).astype(np.float32)
+  beta_vector = np.random.rand(num_cols).astype(np.float32)
+  epsilon = 1e-5
+
+  if args.mode == "accuracy":
+    # version 1
+    print(f">>>> Running version 1")
+    nki_out_v1 = nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector)
+    # version 2
+    print(f">>>> Running version 2")
+    nki_out_v2 = nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector)
+    # compare
+    np_all = np.all(nki_out_v1 == nki_out_v2)
+    print(">>>> LayerNorm V1 and V2 match?", np_all)
+    assert np_all
+
+  else:
+    # perf mode
+    for version in ["v1", "v2"]:
+      print(f">>>> Running version {version}.")
+      func = func_dict[version]
+      benchmark_kernel = nki.benchmark(func, save_neff_name='file.neff',
+                                       save_trace_name='profile.ntff')
+      nki_out_test = benchmark_kernel(input_tensor, epsilon, gamma_vector, beta_vector)
+      os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff")
+      os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff")
diff --git a/src/tutorials/layernorm/layernorm_torch.py b/src/nki_samples/tutorials/layernorm/layernorm_torch.py
similarity index 87%
rename from src/tutorials/layernorm/layernorm_torch.py
rename to src/nki_samples/tutorials/layernorm/layernorm_torch.py
index 59853fd..c2be186 100644
--- a/src/tutorials/layernorm/layernorm_torch.py
+++ b/src/nki_samples/tutorials/layernorm/layernorm_torch.py
@@ -4,9 +4,9 @@
 LayerNorm NKI kernel implementation.
 
 """
+# NKI_EXAMPLE_47_BEGIN
 import torch
 from torch_xla.core import xla_model as xm
-from torch_neuronx import nki_jit
 import argparse
 import os
 
@@ -42,13 +42,16 @@ def parse_args():
     args = parser.parse_args()
     return args
 
+
+from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, \
+  nki_layernorm_kernel_v2
+
 if __name__ == "__main__":
     args = parse_args()
-    from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, nki_layernorm_kernel_v2
     func_dict = {"v1": nki_layernorm_kernel_v1,
                  "v2": nki_layernorm_kernel_v2,
                  }
-    
+
     device = xm.xla_device()
     num_rows = args.nrows
     num_cols = args.ncols
@@ -58,7 +61,7 @@ def parse_args():
     gamma_vector = torch.rand((num_cols), dtype=torch.float32)
     beta_vector = torch.rand((num_cols), dtype=torch.float32)
     epsilon = 1e-5
-    
+
     # Compute torch layernorm layer in cpu
     output_torch = layernorm_layer(input_tensor, epsilon, gamma_vector, beta_vector)
 
@@ -66,17 +69,15 @@ def parse_args():
     input_tensor = input_tensor.to(device=device)
     gamma_vector = gamma_vector.to(device=device)
     beta_vector = beta_vector.to(device=device)
-    output_nki = torch.zeros((num_rows, num_cols), dtype=torch.float32).to(device=device)
 
     print(f">>>> Running version {args.version}.")
     func = func_dict[args.version]
 
     # add nki_jit decorator
-    nki_layernorm_kernel = nki_jit(func)
 
     # Compute NKI layernorm kernel in NeuronDevice
     xm.mark_step()
-    nki_layernorm_kernel(input_tensor, epsilon, gamma_vector, beta_vector, output_nki)
+    output_nki = func(input_tensor, epsilon, gamma_vector, beta_vector)
     xm.mark_step()
     output_nki = output_nki.to(device='cpu')
 
@@ -86,5 +87,6 @@ def parse_args():
         print("NKI and Torch match")
     else:
         print("NKI and Torch differ")
-    
+        # NKI_EXAMPLE_47_END
+
     assert allclose
\ No newline at end of file
diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
similarity index 94%
rename from src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
rename to src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
index 7aeb5d6..8f913f2 100644
--- a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
+++ b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py
@@ -12,7 +12,9 @@
 import numpy as np
 
 
-def nki_matmul_basic_(lhsT, rhs, result):
+# NKI_EXAMPLE_16_BEGIN
+@nki.jit
+def nki_matmul_basic_(lhsT, rhs):
   """NKI kernel to compute a 64x128x512 matrix multiplication operation
 
   Args:
@@ -20,8 +22,11 @@ def nki_matmul_basic_(lhsT, rhs, result):
         matrix multiplication, delivered transposed for optimal performance
       rhs: an input tensor of shape [128,512], a right hand side argument of the
         matrix multiplication
+  Returns:
       result: the resulting output tensor of shape [64,512]
   """
+  result = nl.ndarray((64, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm)
+
   # Defining indexes for input LHS.T
   # - Note: here we take LayoutConstraint #1 into account:
   # "For MatMult, contraction axis must be mapped to P-dim"
@@ -53,8 +58,13 @@ def nki_matmul_basic_(lhsT, rhs, result):
   # This dictates which indices to use to address the result tile.
   nl.store(result[i_out_p, i_out_f], value=result_sbuf)
 
+  return result
+  # NKI_EXAMPLE_16_END
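+
+# Hypothetical usage (an illustration only): with lhsT of shape [128, 64]
+# (the transposed left-hand side) and rhs of shape [128, 512],
+#   result = nki_matmul_basic_(lhsT, rhs)
+# returns the [64, 512] product.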
+
 
-def nki_matmul_tiled_(lhsT, rhs, result):
+# NKI_EXAMPLE_18_BEGIN
+@nki.jit
+def nki_matmul_tiled_(lhsT, rhs):
   """NKI kernel to compute a matrix multiplication operation in a tiled manner
 
   Args:
@@ -64,12 +74,14 @@ def nki_matmul_tiled_(lhsT, rhs, result):
       rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
         is a multiple of 512.  It is the right-hand-side argument of the matrix
         multiplication.
+  Returns:
       result: the resulting output tensor of shape [M,N]
   """
 
   K, M = lhsT.shape
   K_, N = rhs.shape
   assert K == K_, "lhsT and rhs must have the same contraction dimension"
+  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)
 
   TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
   TILE_K = nl.tile_size.pmax  # 128
@@ -100,8 +112,13 @@ def nki_matmul_tiled_(lhsT, rhs, result):
       nl.store(result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N],
                value=res_sb)
 
+  return result
+  # NKI_EXAMPLE_18_END
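+
+# Worked tiling example (an illustration, assuming M=512, K=1024, N=2048 and
+# the tile sizes above): the kernel covers M // 128 = 4 stationary tiles by
+# N // 512 = 4 moving tiles, and each output tile accumulates K // 128 = 8
+# nc_matmul calls in PSUM before the single store into `result`.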
 
-def nki_matmul_hoist_load_(lhsT, rhs, result):
+
+# NKI_EXAMPLE_19_BEGIN
+@nki.jit
+def nki_matmul_hoist_load_(lhsT, rhs):
   """NKI kernel to compute a matrix multiplication operation in a tiled manner
      while hoisting the load of the lhsT and rhs to outer loops.
 
@@ -112,12 +129,14 @@ def nki_matmul_hoist_load_(lhsT, rhs, result):
       rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
         is a multiple of 512.  It is the right-hand-side argument of the matrix
         multiplication.
+  Returns:
       result: the resulting output tensor of shape [M,N]
   """
 
   K, M = lhsT.shape
   K_, N = rhs.shape
   assert K == K_, "lhsT and rhs must have the same contraction dimension"
+  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)
 
   TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
   TILE_K = nl.tile_size.pmax  # 128
@@ -163,8 +182,13 @@ def nki_matmul_hoist_load_(lhsT, rhs, result):
       res_sb = nl.copy(res_psum, dtype=result.dtype)
       nl.store(result[m * TILE_M + i_res.p, n * TILE_N + i_res.x], value=res_sb)
 
+  return result
+  # NKI_EXAMPLE_19_END
+
 
-def nki_matmul_block_free_dimension_(lhsT, rhs, result):
+# NKI_EXAMPLE_20_BEGIN
+@nki.jit
+def nki_matmul_block_free_dimension_(lhsT, rhs):
   """NKI kernel to compute a matrix multiplication operation while blocking the
      free dimensions of the LHS and RHS to improve memory access pattern.
 
@@ -175,12 +199,14 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result):
       rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
         is a multiple of 512.  It is the right-hand-side argument of the matrix
         multiplication.
+  Returns:
       result: the resulting output tensor of shape [M,N]
   """
 
   K, M = lhsT.shape
   K_, N = rhs.shape
   assert K == K_, "lhsT and rhs must have the same contraction dimension"
+  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)
 
   TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
   TILE_K = nl.tile_size.pmax  # 128
@@ -243,11 +269,15 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result):
                           (n * TILES_IN_BLOCK_N + bn) * TILE_N + i_res.x],
                    value=res_sb)
 
+  return result
+  # NKI_EXAMPLE_20_END
 
+
+# NKI_EXAMPLE_21_BEGIN
+@nki.jit
 def nki_matmul_fully_optimized_(
     lhsT,
     rhs,
-    result,
     # Meta-parameters
     TILES_IN_BLOCK_M=16,
     TILES_IN_BLOCK_N=2,
@@ -264,13 +294,15 @@ def nki_matmul_fully_optimized_(
       rhs: an input tensor of shape [K,N],  where K is a multiple of 128 *
         TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N.  It is
         the right-hand-side argument of the matrix multiplication.
-      result: the resulting output tensor of shape [M,N]
       TILES_IN_BLOCK_*: meta parameters to control blocking dimensions
+  Returns:
+      result: the resulting output tensor of shape [M,N]
   """
 
   K, M = lhsT.shape
   K_, N = rhs.shape
   assert K == K_, "lhsT and rhs must have the same contraction dimension"
+  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)
 
   TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
   TILE_K = nl.tile_size.pmax  # 128
@@ -360,16 +392,19 @@ def nki_matmul_fully_optimized_(
                         BLOCK_N * n + i_res_packed.x],
                  value=result_packed[i_res_packed.p, i_res_packed.x])
 
+  return result
+# NKI_EXAMPLE_21_END
+
 
+# NKI_EXAMPLE_23_BEGIN
 if __name__ == "__main__":
   # Benchmarking with large matrices to show the differences more clearly
   lhsT = nt.tensor[[8192, 4096], nl.bfloat16]
   rhs = nt.tensor[[8192, 8192], nl.bfloat16]
-  output = nt.tensor[[4096, 8192], nl.bfloat16]
 
   def benchmark_nki(nki_func):
     bench_func = nki.benchmark(warmup=5, iters=10)(nki_func)
-    bench_func(lhsT, rhs, output)
+    bench_func(lhsT, rhs)
     latency_res = bench_func.benchmark_result.nc_latency
     p99 = latency_res.get_latency_percentile(99)
     print("Latency: {:.2f} ms (P99)".format(p99 / 1000.0))
@@ -385,3 +420,4 @@ def benchmark_nki(nki_func):
 
   print("Benchmarking nki_matmul_fully_optimized")
   benchmark_nki(nki_matmul_fully_optimized_)
+  # NKI_EXAMPLE_23_END
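
A minimal numerical check of the refactored matmul kernels above (a sketch; shapes and tolerances are illustrative assumptions): since each kernel now allocates and returns its result under @nki.jit, it can be validated directly against numpy, assuming the entry points accept plain numpy arrays the way the other tutorial __main__ blocks in this change do. The shapes below satisfy the multiple-of-128/512 constraints stated in the docstrings.

import numpy as np

lhsT_check = np.random.rand(128, 512).astype(np.float32)   # [K, M] with K=128, M=512
rhs_check = np.random.rand(128, 1024).astype(np.float32)   # [K, N] with N=1024

out_nki = nki_matmul_tiled_(lhsT_check, rhs_check)          # result is returned, not passed in
out_ref = np.matmul(lhsT_check.T, rhs_check)
assert np.allclose(out_nki, out_ref, atol=1e-4, rtol=1e-2)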
diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py
similarity index 83%
rename from src/tutorials/matrix_multiplication/matrix_multiplication_torch.py
rename to src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py
index ec0084c..de39ce8 100644
--- a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py
+++ b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py
@@ -7,23 +7,21 @@
 
 import torch
 from torch_xla.core import xla_model as xm
-from torch_neuronx import nki_jit
 
 from matrix_multiplication_nki_kernels import nki_matmul_basic_, nki_matmul_tiled_, nki_matmul_hoist_load_, nki_matmul_block_free_dimension_, nki_matmul_fully_optimized_
 
 if __name__ == "__main__":
 
+  # NKI_EXAMPLE_17_BEGIN
   device = xm.xla_device()
   cpu = torch.device('cpu')
 
   # Test the small workload with basic kernel
   lhs_small = torch.rand((64, 128), dtype=torch.bfloat16, device=device)
   rhs_small = torch.rand((128, 512), dtype=torch.bfloat16, device=device)
-  output_small = torch.zeros((64, 512), dtype=torch.bfloat16, device=device)
 
   # Run NKI kernel
-  nki_matmul_basic_jit = nki_jit(nki_matmul_basic_)
-  nki_matmul_basic_jit(lhs_small.T, rhs_small, output_small)
+  output_small = nki_matmul_basic_(lhs_small.T, rhs_small)
 
   # Run torch reference
   output_small_torch = torch.matmul(lhs_small, rhs_small)
@@ -34,18 +32,18 @@
     print("NKI and Torch match")
   else:
     print("NKI and Torch differ")
+    # NKI_EXAMPLE_17_END
 
+  # NKI_EXAMPLE_22_BEGIN
   # Test the large workload with tiled kernels
   lhs = torch.rand((4096, 1024), dtype=torch.bfloat16, device=device)
   rhs = torch.rand((1024, 2048), dtype=torch.bfloat16, device=device)
-  output = torch.zeros((4096, 2048), dtype=torch.bfloat16, device=device)
 
   # Run torch reference
   output_torch = torch.matmul(lhs, rhs).to(device=cpu)
 
   def check_match(nki_func):
-    jit_func = nki_jit(nki_func)
-    jit_func(lhs.T, rhs, output)
+    output = nki_func(lhs.T, rhs)
     output_nki = output.to(device=cpu)
     if torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2):
       print("NKI and Torch match")
@@ -63,3 +61,4 @@ def check_match(nki_func):
 
   print("Checking correctness of nki_matmul_fully_optimized")
   check_match(nki_matmul_fully_optimized_)
+  # NKI_EXAMPLE_22_END
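
Why lhs.T is passed in the calls above (a brief sketch): the NKI matmul kernels take the stationary operand already transposed to a [K, M] layout, so the [M, K] = (4096, 1024) tensor created for the large test is handed over as lhs.T with shape (1024, 4096).

import torch
lhs_demo = torch.rand((4096, 1024), dtype=torch.bfloat16)
assert lhs_demo.T.shape == torch.Size([1024, 4096])        # [K, M] layout expected by the kernels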
diff --git a/src/tutorials/rmsnorm/rmsnorm_jax.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py
similarity index 84%
rename from src/tutorials/rmsnorm/rmsnorm_jax.py
rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py
index 5b412d8..f0efc20 100644
--- a/src/tutorials/rmsnorm/rmsnorm_jax.py
+++ b/src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py
@@ -7,9 +7,9 @@
 
 import jax
 import jax.numpy as jnp
-from jax_neuronx import nki_call
 from rmsnorm_nki_kernels import nki_rmsnorm_kernel
 
+# NKI_EXAMPLE_44_BEGIN
 # Reference JAX implementation
 def jax_rms_norm(a_tensor, g_tensor):
   # Square the tensor (element-wise)
@@ -26,11 +26,7 @@ def jax_rms_norm(a_tensor, g_tensor):
 a_tensor = jax.random.uniform(a_key, (250, 512))
 g_tensor = jax.random.uniform(g_key, (512,))
 
-output_nki = nki_call(
-  nki_rmsnorm_kernel,
-  a_tensor, g_tensor,
-  out_shape=jax.ShapeDtypeStruct(a_tensor.shape, dtype=a_tensor.dtype),
-)
+output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor)
 
 print(a_tensor)
 
@@ -43,3 +39,4 @@ def jax_rms_norm(a_tensor, g_tensor):
   print("NKI and JAX match")
 else:
   print("NKI and JAX differ")
+  # NKI_EXAMPLE_44_END
diff --git a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py
similarity index 90%
rename from src/tutorials/rmsnorm/rmsnorm_nki_kernels.py
rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py
index 140b682..402eecd 100644
--- a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py
+++ b/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py
@@ -6,20 +6,23 @@
 """
 
 import numpy as np
+# NKI_EXAMPLE_42_BEGIN
 import math
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 
 
-def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor):
+@nki.jit
+def nki_rmsnorm_kernel(a_tensor, g_tensor):
   # Calculate out_tensor = a_tensor/RMS(a_tensor) * g_tensor
   # Where RMS(a_tensor) = sqrt((1/N) * sum(a_tensor * a_tensor))
   # and N = a_tensor.shape[1]
   # Reduction (mean) is performed in the free (2nd) dimension
+  out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype,
+                          buffer=nl.shared_hbm)
 
   # Make sure shapes match
   assert a_tensor.shape[1] == g_tensor.shape[0]
-  assert a_tensor.shape == out_tensor.shape
 
   # Generate tensor indices to index input tensor
   ix = nl.arange(128)[:, None]
@@ -68,14 +71,15 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor):
     nl.store(out_tensor[i * 128 + ix, iy], value=out_tile,
             mask=(i * 128 + ix < num_rows))
 
+  return out_tensor
+  # NKI_EXAMPLE_42_END
+
 
 if __name__ == "__main__":
   a = np.random.rand(128, 512).astype(np.float32)
   g = np.random.rand(512).astype(np.float32)
-  output_nki  = np.zeros(a.shape, dtype=a.dtype)
 
-  nki_rmsnorm_kernel_baremetal = nki.baremetal(nki_rmsnorm_kernel)
-  nki_rmsnorm_kernel_baremetal(a, g, output_nki)
+  output_nki = nki_rmsnorm_kernel(a, g)
   print(f"output_nki={output_nki}")
 
   # One-line numpy RMSNorm
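
A worked view of the partial-tile masking in the store above (a sketch using the 250-row inputs from the torch and jax versions of this tutorial): with 128 partitions per tile, 250 rows need ceil(250 / 128) = 2 iterations, and in the second one the mask (i * 128 + ix < num_rows) keeps rows 128..249 while disabling the 6 padding rows 250..255.

import math
num_rows, partitions = 250, 128
trip_count = math.ceil(num_rows / partitions)
valid_rows_last_tile = num_rows - (trip_count - 1) * partitions
assert trip_count == 2 and valid_rows_last_tile == 122     # rows 250..255 are masked off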
diff --git a/src/tutorials/rmsnorm/rmsnorm_torch.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py
similarity index 82%
rename from src/tutorials/rmsnorm/rmsnorm_torch.py
rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py
index 71ced3e..c9bfc69 100644
--- a/src/tutorials/rmsnorm/rmsnorm_torch.py
+++ b/src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py
@@ -5,11 +5,11 @@
 
 """
 
-from torch_neuronx.xla_impl.ops import nki_jit
 import torch
 import os
 from rmsnorm_nki_kernels import nki_rmsnorm_kernel
 
+# NKI_EXAMPLE_43_BEGIN
 # Reference torch implementation
 def torch_rmsnorm_kernel(a_tensor, g_tensor):
   # Square the tensor (element-wise)
@@ -25,13 +25,10 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor):
 from torch_xla.core import xla_model as xm
 device = xm.xla_device()
 
-nki_rmsnorm_kernel = nki_jit(nki_rmsnorm_kernel)
-
 a_tensor = torch.rand((250, 512), dtype=torch.float32).to(device=device)
 g_tensor = torch.rand((512), dtype=torch.float32).to(device=device)
-output_nki = torch.zeros((250, 512), dtype=torch.float32).to(device=device)
 
-nki_rmsnorm_kernel(a_tensor, g_tensor, output_nki)
+output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor)
 print(f"output_nki={output_nki}")
 
 output_torch = torch_rmsnorm_kernel(a_tensor, g_tensor)
@@ -41,3 +38,4 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor):
   print("NKI and Torch match")
 else:
   print("NKI and Torch differ")
+# NKI_EXAMPLE_43_END
diff --git a/src/tutorials/sd_attention/sd_attention_nki_kernels.py b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py
similarity index 81%
rename from src/tutorials/sd_attention/sd_attention_nki_kernels.py
rename to src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py
index e5eec25..6d1f781 100644
--- a/src/tutorials/sd_attention/sd_attention_nki_kernels.py
+++ b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py
@@ -12,7 +12,9 @@
 import argparse
 from scipy.special import softmax
 
-def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False,
+# NKI_EXAMPLE_31_BEGIN
+@nki.jit
+def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
                                            mixed_percision=True):
   """
   Fused self attention kernel for small head dimension Stable Diffusion workload, 
@@ -44,7 +46,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
   # Assume all IO tensors have the same dtype
   kernel_dtype = q_ref.dtype
   pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
-  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype
+  assert q_ref.dtype == k_ref.dtype == v_ref.dtype
 
   # Shape checking
   seqlen, d_head = q_ref.shape
@@ -53,7 +55,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
   assert tuple(k_ref.shape) == (seqlen, d_head), 'Input shape mismatch!'
   assert tuple(v_ref.shape) == (seqlen,d_head), \
   f'Input shape mismatch! Expected: {(seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
-  assert tuple(out_ref.shape) == (seqlen, d_head), 'Output shape mismatch!'
+  out_ref = nl.ndarray((seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm)
 
   # Softmax scaling factor, multiplied onto Q
   softmax_scale = 0.125
@@ -210,58 +212,61 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau
       out_ref[i_q_seq_tile * q_seq_tile_size + if_out, ip_out],
       value=attn_res_div)
 
+  return out_ref
+# NKI_EXAMPLE_31_END
+
 
 def parse_args():
-    parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.")
-    parser.add_argument("--mode",
-                        choices=["accuracy", "perf"],
-                        default="accuracy",
-                        help="""Do accuracy test or perf test.
-                                Accuracy test uses cpu golden output as golden reference.
-                             """)
+  parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.")
+  parser.add_argument("--mode",
+                      choices=["accuracy", "perf"],
+                      default="accuracy",
+                      help="""Do accuracy test or perf test.
+                              Accuracy test uses the CPU output as the golden reference.
+                           """)
+
+  args = parser.parse_args()
+  return args
 
-    args = parser.parse_args()
-    return args
 
 def cpu_golden_attn(q, k, v):
-    softmax_scale = 0.125
+  softmax_scale = 0.125
 
-    q_scaled = q * softmax_scale
-    raw_score = np.matmul(q_scaled, k.transpose())
-    norm_score = softmax(raw_score, axis=-1)
+  q_scaled = q * softmax_scale
+  raw_score = np.matmul(q_scaled, k.transpose())
+  norm_score = softmax(raw_score, axis=-1)
 
-    return np.matmul(norm_score, v)
+  return np.matmul(norm_score, v)
 
 
 if __name__ == "__main__":
-    args = parse_args()
-
-    print(f"Running {args.mode} mode.")
-
-    seqlen, d_head = 4096, 64
-    
-    # Set up input tensors
-    dtype = np.float32
-    q_tensor = np.random.rand(seqlen, d_head).astype(dtype)
-    k_tensor = np.random.rand(seqlen, d_head).astype(dtype)
-    v_tensor = np.random.rand(seqlen, d_head).astype(dtype)
-    output_nki = np.empty((seqlen, d_head), dtype=dtype)
-    output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor)
-    
-    if args.mode == "accuracy":
-        nki.baremetal(fused_self_attn_for_SD_small_head_size)\
-                        (q_tensor, k_tensor, v_tensor, output_nki)
-        allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3)
-        print(f">>>> SD attention matches CPU reference? {allclose}")
-        assert allclose, "Accuracy check fails!"
-
-    else:
-        benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size,
-                        save_neff_name='file.neff',
-                        save_trace_name='profile.ntff')
-        benchmark_func(q_tensor, k_tensor, v_tensor, output_nki)
-        
-        metrics = benchmark_func.benchmark_result.nc_latency
-        print(">>>> SD attention benchmark results")
-        print("latency.p50 = " + str(metrics.get_latency_percentile(50)))
-        print("latency.p99 = " + str(metrics.get_latency_percentile(99)))
\ No newline at end of file
+  args = parse_args()
+
+  print(f"Running {args.mode} mode.")
+
+  seqlen, d_head = 4096, 64
+
+  # Set up input tensors
+  dtype = np.float32
+  q_tensor = np.random.rand(seqlen, d_head).astype(dtype)
+  k_tensor = np.random.rand(seqlen, d_head).astype(dtype)
+  v_tensor = np.random.rand(seqlen, d_head).astype(dtype)
+  output_nki = np.empty((seqlen, d_head), dtype=dtype)
+  output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor)
+
+  if args.mode == "accuracy":
+    output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor)
+    allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3)
+    print(f">>>> SD attention matches CPU reference? {allclose}")
+    assert allclose, "Accuracy check fails!"
+
+  else:
+    benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size,
+                                   save_neff_name='file.neff',
+                                   save_trace_name='profile.ntff')
+    benchmark_func(q_tensor, k_tensor, v_tensor)
+
+    metrics = benchmark_func.benchmark_result.nc_latency
+    print(">>>> SD attention benchmark results")
+    print("latency.p50 = " + str(metrics.get_latency_percentile(50)))
+    print("latency.p99 = " + str(metrics.get_latency_percentile(99)))
\ No newline at end of file
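
Where the hard-coded softmax_scale = 0.125 comes from (a sketch): the kernel and cpu_golden_attn above use d_head = 64, and the conventional attention scaling 1 / sqrt(d_head) is exactly 1/8.

import math
assert math.isclose(1.0 / math.sqrt(64), 0.125)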
diff --git a/src/tutorials/sd_attention/sd_attention_torch.py b/src/nki_samples/tutorials/sd_attention/sd_attention_torch.py
similarity index 79%
rename from src/tutorials/sd_attention/sd_attention_torch.py
rename to src/nki_samples/tutorials/sd_attention/sd_attention_torch.py
index f124607..639e5cf 100644
--- a/src/tutorials/sd_attention/sd_attention_torch.py
+++ b/src/nki_samples/tutorials/sd_attention/sd_attention_torch.py
@@ -5,8 +5,8 @@
 
 """
 
+# NKI_EXAMPLE_32_BEGIN
 import torch
-from torch_neuronx.xla_impl.ops import nki_jit
 from torch_xla.core import xla_model as xm
 
 from sd_attention_nki_kernels import fused_self_attn_for_SD_small_head_size
@@ -28,10 +28,8 @@ def cpu_golden_attn(q, k, v):
   q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
   k_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
   v_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
-  output_nki = torch.zeros((4096, 64), dtype=torch.float32).to(device=device)
 
-  nki_func = nki_jit(func=fused_self_attn_for_SD_small_head_size)
-  nki_func(q_tensor, k_tensor, v_tensor, output_nki)
+  output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor)
 
   output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor)
 
@@ -42,4 +40,5 @@ def cpu_golden_attn(q, k, v):
   else:
     print("NKI and Torch differ")
 
-  assert allclose
\ No newline at end of file
+  assert allclose
+  # NKI_EXAMPLE_32_END
diff --git a/src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py
new file mode 100644
index 0000000..e40f962
--- /dev/null
+++ b/src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py
@@ -0,0 +1,35 @@
+"""
+Copyright (C) 2024, Amazon.com. All Rights Reserved
+
+JAX implementation for tensor addition NKI tutorial.
+
+"""
+# NKI_EXAMPLE_30_BEGIN
+import jax
+import jax.numpy as jnp
+# NKI_EXAMPLE_30_END
+
+from tensor_addition_nki_kernels import nki_tensor_add
+
+
+# NKI_EXAMPLE_30_BEGIN
+if __name__ == "__main__":
+
+  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
+  a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16)
+  b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16)
+
+  output_nki = nki_tensor_add(a, b)
+  print(f"output_nki={output_nki}")
+
+  output_jax = a + b
+  print(f"output_jax={output_jax}")
+
+  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
+  if allclose:
+    print("NKI and JAX match")
+  else:
+    print("NKI and JAX differ")
+
+  assert allclose
+  # NKI_EXAMPLE_30_END
diff --git a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py
similarity index 76%
rename from src/tutorials/tensor_addition/tensor_addition_nki_kernels.py
rename to src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py
index 2b49237..ea72488 100644
--- a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py
+++ b/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py
@@ -5,20 +5,26 @@
 
 """
 import numpy as np
+# NKI_EXAMPLE_27_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 
 
-def nki_tensor_add_kernel_(a_input, b_input, c_output):
+@nki.jit
+def nki_tensor_add_kernel_(a_input, b_input):
   """NKI kernel to compute element-wise addition of two input tensors
 
-  This kernel assumes strict input/output tile-sizes, of up-to [128,512]
+  This kernel assumes that the input/output tensor shapes can be uniformly tiled into [128,512] tiles
 
   Args:
-      a_input: a first input tensor, of shape [128,512]
-      b_input: a second input tensor, of shape [128,512]
-      c_output: an output tensor, of shape [128,512]
+      a_input: the first input tensor
+      b_input: the second input tensor
+
+  Returns:
+      c_output: an output tensor
   """
+  # Create the output tensor in shared HBM so every SPMD instance writes its tile into one result
+  c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)
 
   # Calculate tile offsets based on current 'program'
   offset_i_x = nl.program_id(0) * 128
@@ -39,7 +45,12 @@ def nki_tensor_add_kernel_(a_input, b_input, c_output):
   # store the addition results back to device memory (c_output)
   nl.store(c_output[ix, iy], value=c_tile)
 
+  # Transfer the ownership of `c_output` to the caller
+  return c_output
+  # NKI_EXAMPLE_27_END
+
 
+# NKI_EXAMPLE_28_BEGIN
 def nki_tensor_add(a_input, b_input):
   """NKI kernel caller to compute element-wise addition of two input tensors
 
@@ -57,12 +68,9 @@ def nki_tensor_add(a_input, b_input):
   # In this case, we use a 2D grid where the size of each invocation is 128x512
   grid_x = a_input.shape[0] // 128
   grid_y = a_input.shape[1] // 512
-  c_output = np.zeros(a_input.shape, dtype=a_input.dtype)
-
-  nki_tensor_add_kernel_baremetal = nki.baremetal(nki_tensor_add_kernel_)
-  nki_tensor_add_kernel_baremetal[grid_x, grid_y](a_input, b_input, c_output)
 
-  return c_output
+  return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
+  # NKI_EXAMPLE_28_END
 
 
 if __name__ == "__main__":
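
A worked example of the SPMD grid computed above (a sketch, using the (256, 1024) inputs from the torch and jax versions of this tutorial): grid_x = 256 // 128 = 2 and grid_y = 1024 // 512 = 2, so nki_tensor_add_kernel_[2, 2](a_input, b_input) launches four instances, each adding the 128x512 tile selected by nl.program_id(0) and nl.program_id(1).

shape = (256, 1024)
grid_x, grid_y = shape[0] // 128, shape[1] // 512
assert (grid_x, grid_y) == (2, 2)                          # four 128x512 tiles in total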
diff --git a/src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py
new file mode 100644
index 0000000..83673e5
--- /dev/null
+++ b/src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py
@@ -0,0 +1,35 @@
+"""
+Copyright (C) 2024, Amazon.com. All Rights Reserved
+
+PyTorch implementation for tensor addition NKI tutorial.
+
+"""
+# NKI_EXAMPLE_29_BEGIN
+import torch
+from torch_xla.core import xla_model as xm
+# NKI_EXAMPLE_29_END
+
+from tensor_addition_nki_kernels import nki_tensor_add
+
+
+# NKI_EXAMPLE_29_BEGIN
+if __name__ == "__main__":
+  device = xm.xla_device()
+
+  a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
+  b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
+
+  output_nki = nki_tensor_add(a, b)
+  print(f"output_nki={output_nki}")
+
+  output_torch = a + b
+  print(f"output_torch={output_torch}")
+
+  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
+  if allclose:
+    print("NKI and Torch match")
+  else:
+    print("NKI and Torch differ")
+
+  assert allclose
+  # NKI_EXAMPLE_29_END
diff --git a/src/tutorials/transpose2d/transpose2d_jax.py b/src/nki_samples/tutorials/transpose2d/transpose2d_jax.py
similarity index 65%
rename from src/tutorials/transpose2d/transpose2d_jax.py
rename to src/nki_samples/tutorials/transpose2d/transpose2d_jax.py
index 024782c..f23ceef 100644
--- a/src/tutorials/transpose2d/transpose2d_jax.py
+++ b/src/nki_samples/tutorials/transpose2d/transpose2d_jax.py
@@ -5,25 +5,18 @@
 
 """
 
+# NKI_EXAMPLE_36_BEGIN
 import jax
 import jax.numpy as jnp
-from functools import partial
-from jax_neuronx import nki_call
+# NKI_EXAMPLE_36_END
 
 from transpose2d_nki_kernels import tensor_transpose2D_kernel_
 
-
-def transpose2D(in_tensor, shape2D):
-  return nki_call(
-    partial(tensor_transpose2D_kernel_, shape2D=shape2D),
-    in_tensor,
-    out_shape=jax.ShapeDtypeStruct(in_tensor.shape, dtype=in_tensor.dtype)
-  )
-
+# NKI_EXAMPLE_36_BEGIN
 if __name__ == "__main__":
   P, X, Y = 5, 37, 44
   a = jax.random.uniform(jax.random.PRNGKey(42), (P, X * Y))
-  a_t_nki = transpose2D(a, (X, Y))
+  a_t_nki = tensor_transpose2D_kernel_(a, shape2D=(X, Y))
 
   a_t_jax = jnp.transpose(a.reshape(P, X, Y), axes=(0, 2, 1)).reshape(P, X * Y)
   print(a, a_t_nki, a_t_jax)
@@ -35,3 +28,4 @@ def transpose2D(in_tensor, shape2D):
     print("NKI and JAX differ")
 
   assert allclose
+# NKI_EXAMPLE_36_END
diff --git a/src/tutorials/transpose2d/transpose2d_nki_kernels.py b/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py
similarity index 90%
rename from src/tutorials/transpose2d/transpose2d_nki_kernels.py
rename to src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py
index d993c7e..171e6ed 100644
--- a/src/tutorials/transpose2d/transpose2d_nki_kernels.py
+++ b/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py
@@ -5,11 +5,13 @@
 """
 
 import numpy as np
+# NKI_EXAMPLE_33_BEGIN
 import neuronxcc.nki as nki
 import neuronxcc.nki.language as nl
 
 
-def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
+@nki.jit
+def tensor_transpose2D_kernel_(in_tensor, shape2D):
   """
   NKI kernel to reorder the elements on axis[1] of the input tensor.
 
@@ -36,6 +38,10 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
     shape2D: tuple representing the dimensions to be transposed: (#rows, #cols)
-    out_tensor: an output (transposed) tensor
+
+  Returns:
+    out_tensor: an output (transposed) tensor
   """
+  out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype,
+                          buffer=nl.shared_hbm)
   # Gather input shapes
   sz_p, _ = in_tensor.shape
 
@@ -64,14 +68,15 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
   # Finally, we store out_tile to external memory
   nl.store(out_tensor, value=out_tile)
 
+  return out_tensor
+  # NKI_EXAMPLE_33_END
+
 
 if __name__ == "__main__":
   P, X, Y = 5, 3, 4
   a = np.arange(P*X*Y, dtype=np.int8).reshape((P, X*Y))
-  a_t_nki = np.zeros((P, Y*X), dtype=np.int8)
 
-  tensor_transpose2D_kernel_torch = nki.baremetal(tensor_transpose2D_kernel_)
-  tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y))
+  a_t_nki = tensor_transpose2D_kernel_(a, (X, Y))
 
   a_t_np = np.transpose(a.reshape(P, X, Y), (0, 2, 1)).reshape(P, X * Y)
 
diff --git a/src/tutorials/transpose2d/transpose2d_torch.py b/src/nki_samples/tutorials/transpose2d/transpose2d_torch.py
similarity index 82%
rename from src/tutorials/transpose2d/transpose2d_torch.py
rename to src/nki_samples/tutorials/transpose2d/transpose2d_torch.py
index 71083d7..61fe367 100644
--- a/src/tutorials/transpose2d/transpose2d_torch.py
+++ b/src/nki_samples/tutorials/transpose2d/transpose2d_torch.py
@@ -4,13 +4,15 @@
 PyTorch implementation for transpose2d NKI tutorial.
 """
 
+# NKI_EXAMPLE_34_BEGIN
 import torch
 from torch_xla.core import xla_model as xm
-from torch_neuronx import nki_jit
+# NKI_EXAMPLE_34_END
 
 from transpose2d_nki_kernels import tensor_transpose2D_kernel_
 
 
+# NKI_EXAMPLE_34_BEGIN
 if __name__ == "__main__":
   device = xm.xla_device()
 
@@ -18,8 +20,7 @@
   a = torch.arange(P*X*Y, dtype=torch.int8).reshape((P, X*Y)).to(device=device)
   a_t_nki = torch.zeros((P, Y*X), dtype=torch.int8).to(device=device)
 
-  tensor_transpose2D_kernel_torch = nki_jit(tensor_transpose2D_kernel_)
-  tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y))
+  a_t_nki = tensor_transpose2D_kernel_(a, (X, Y))
 
   a_t_torch = torch.transpose(a.reshape(P, X, Y), 1, 2).reshape(P, X * Y)
 
@@ -32,3 +33,4 @@
     print("NKI and PyTorch differ")
 
   assert allclose
+  # NKI_EXAMPLE_34_END
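
A tiny illustration of the column reorder performed by tensor_transpose2D_kernel_ (a sketch matching the numpy and torch checks in the transpose2d tutorials above): with shape2D = (X, Y), the value at column x*Y + y of each row moves to column y*X + x, i.e. each row's flattened (X, Y) block is transposed.

import numpy as np
X, Y = 3, 4
row = np.arange(X * Y)                                     # columns 0..11 of one row
reordered = row.reshape(X, Y).T.reshape(-1)                # where each column ends up
print(reordered)                                           # [0 4 8 1 5 9 2 6 10 3 7 11]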
diff --git a/src/reference/__init__.py b/src/reference/__init__.py
deleted file mode 100644
index ad4a18a..0000000
--- a/src/reference/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) 2023, Amazon.com. All Rights Reserved
-
-"""
-Package containing public kernels for Neuron Kernel Interface (NKI).
-
-Kernels here are the same to the ones available in the 
-NKI Github Sample Repo.
-
-TODO: Insert link to Github Repo when available
-"""
-from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size, flash_attn_bwd, flash_fwd
-from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel, select_and_scatter_kernel
diff --git a/src/reference/attention.py b/src/reference/attention.py
deleted file mode 100644
index 81704b5..0000000
--- a/src/reference/attention.py
+++ /dev/null
@@ -1,1031 +0,0 @@
-"""
-Copyright (c) 2023, Amazon.com. All Rights Reserved
-
-kernels - Builtin high performance attention kernels
-
-"""
-import numpy as np
-
-from neuronxcc.nki import trace
-import neuronxcc.nki.isa as nisa
-import neuronxcc.nki.language as nl
-
-from neuronxcc.nki.language import par_dim
-from dataclasses import dataclass
-
-def div_ceil(n, d):
-  return (n + d - 1) // d
-
-@dataclass(frozen=True)
-class FlashConfig:
-  """
-    Config class for flash attention with default values
-  """
-  seq_tile_size:int = 2048
-  training:bool = True
-  should_transpose_v:bool = False
-
-  __annotations__ = {
-    'seq_tile_size': int,
-    'training': bool,
-    'should_transpose_v': bool
-  }
-
-@trace
-def _flash_attention_core(q_local_tile, k, v,
-                          q_h_per_k_h,
-                          o_buffer, l_buffer, m_buffer,
-                          batch_id, head_id, gqa_head_idx, q_tile_idx,
-                          local_k_large_tile_idx,
-                          kernel_dtype, acc_type,
-                          flash_config: FlashConfig,
-                          olm_buffer_idx=None,
-                          global_k_large_tile_idx=None,
-                          use_causal_mask=False, initialize=False,
-                          B_P_SIZE=128, B_F_SIZE=512, B_D_SIZE=128,
-                          dropout_p=0.0, dropout_p_tensor=None, seed_tensor=None
-                          ):
-  """
-  The flash attention core function to calcualte self attention between a tile of q and a block of K and V.
-  The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF already. The block size of K and V
-  is defined in the seq_tile_size of the flash_config. The results are stored in the following there buffers
-  o_buffer: (num_large_k_tile, B_P_SIZE, d)
-  l_buffer: (num_large_k_tile, B_P_SIZE, 1)
-  m_buffer: (num_large_k_tile, B_P_SIZE, 1)
-  """
-  LARGE_TILE_SZ = flash_config.seq_tile_size
-  REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
-  num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
-  seq_len = k.shape[-1]
-  seq_q_num_tiles = seq_len // B_P_SIZE
-
-  # Indices used by the distributed attention
-  if global_k_large_tile_idx is None:
-    global_k_large_tile_idx = local_k_large_tile_idx
-  if olm_buffer_idx is None:
-    olm_buffer_idx = local_k_large_tile_idx
-
-  i_q_p = nl.arange(B_P_SIZE)[:, None]
-  i_q_f = nl.arange(B_F_SIZE)[None, :]
-  i_d_p = nl.arange(B_D_SIZE)[:, None]
-  i_d_f = nl.arange(B_D_SIZE)[None, :]
-  i_f_128 = nl.arange(B_P_SIZE)[None, :]
-  i_f_k_tiles = nl.arange(num_k_tile_per_large_tile)[None, :]
-
-  # mask are used to only apply computation to the lower half of the matrix,
-  # which reduce the arthimetic intensity by half
-  forward_mask = q_tile_idx * B_P_SIZE >= global_k_large_tile_idx * LARGE_TILE_SZ if use_causal_mask else None
-  # Negation mask is the negation of `forward_mask`, which is used for the
-  # instructions executed on the blocks in the upper triangular section
-  # of the matrix.
-  # These instructions should not be executed when causual mask is disabled.
-  #
-  # For example, the o_buffer still needs to be propagated from o[j-1] to o[j] in
-  # the upper triangular of the matrix.
-  negation_mask = q_tile_idx * B_P_SIZE < global_k_large_tile_idx * LARGE_TILE_SZ if use_causal_mask else None
-
-  qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), buffer=nl.sbuf, dtype=acc_type)
-  max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), dtype=acc_type)
-  for k_i in nl.affine_range(num_k_tile_per_large_tile):
-    qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE),
-                        dtype=np.float32, buffer=nl.psum)  # (128, 512)
-    multiplication_required_selection = global_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE <= q_tile_idx * B_P_SIZE if use_causal_mask else None
-    qk_psum[i_q_p, i_q_f] += nl.matmul(q_local_tile, k[i_d_p, k_i * B_F_SIZE + i_q_f], transpose_x=True,
-                                       mask=multiplication_required_selection) # (p(128), 512)
-
-    if use_causal_mask:
-      left_diagonal_selection = q_tile_idx * B_P_SIZE >= global_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE
-      diagonal_and_right_selection = (q_tile_idx * B_P_SIZE < global_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) & forward_mask
-
-      q_pos = q_tile_idx * B_P_SIZE + i_q_p
-      k_pos = global_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + i_q_f
-      pred = q_pos >= k_pos
-      # For tiles on and on the right of the diagonal, need to do affine_select.
-      # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
-      qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = nisa.affine_select(
-        pred=pred,
-        on_true_tile=qk_psum[i_q_p, i_q_f], on_false_value=-9984.0, dtype=kernel_dtype,
-        mask=diagonal_and_right_selection)
-
-      # For tiles on the left of the diagonal, direct copy, no select required.
-      qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = \
-        nl.copy(qk_psum[i_q_p, i_q_f], dtype=kernel_dtype, mask=left_diagonal_selection)
-    else:
-      # Simply send psum result back to sbuf
-      qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = \
-        nl.copy(qk_psum[i_q_p, i_q_f], dtype=kernel_dtype)
-
-    # Calculate max of the current tile
-    max_local[i_q_p, k_i] = nisa.tensor_reduce(np.max, qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f], axis=(1,),
-                                        dtype=acc_type, negate=False, mask=forward_mask)
-
-  max_ = nisa.tensor_reduce(np.max, max_local[i_q_p, i_f_k_tiles], axis=(1, ),
-                    dtype=acc_type, negate=False, mask=forward_mask)
-  if not initialize:
-    m_previous = nl.copy(m_buffer[olm_buffer_idx - 1, i_q_p, 0])
-    m_buffer[olm_buffer_idx, i_q_p, 0] = nl.maximum(m_previous, max_, mask=forward_mask) # (128,1)
-    if use_causal_mask:
-      m_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(m_previous, mask=negation_mask)
-
-    m_current = m_buffer[olm_buffer_idx, i_q_p, 0]
-    # Compute scaling factor
-    alpha = nisa.activation(np.exp, m_previous, bias=-1*m_current, scale=1.0, mask=forward_mask)
-    o_previous = nl.copy(o_buffer[olm_buffer_idx-1, i_q_p, i_d_f], mask=forward_mask)
-    o_previous_scaled = nl.multiply(o_previous, alpha, mask=forward_mask)
-  else:
-    m_buffer[0, i_q_p, 0] = nl.copy(max_)
-    m_current = max_
-
-  p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
-  i_r_f = nl.arange(REDUCTION_TILE)[None,: ]
-  p_partial_sum = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type)
-  for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
-    # compute exp(qk-max)
-    p_local[i_q_p, k_r_i * REDUCTION_TILE + i_r_f] = \
-      nisa.activation(np.exp,
-                      qk_res_buf[i_q_p, k_r_i * REDUCTION_TILE + i_r_f],
-                      bias=-1 * m_current,
-                      scale=1.0,
-                      dtype=kernel_dtype,
-                      mask=forward_mask)
-
-    # dropout
-    if dropout_p > 0.0:
-      for k_d_i in nl.sequential_range(REDUCTION_TILE // B_F_SIZE):
-        offset = k_d_i + k_r_i * (REDUCTION_TILE // B_F_SIZE) \
-                  + global_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \
-                  + q_tile_idx * (seq_len // B_F_SIZE) \
-                  + (head_id * q_h_per_k_h + gqa_head_idx) * (seq_len // B_F_SIZE) * seq_q_num_tiles \
-                  + batch_id * nl.num_programs(1) * (seq_len // B_F_SIZE) * seq_q_num_tiles
-        offset_seed = nl.add(seed_tensor[0, 0], offset, mask=forward_mask)
-        nl.random_seed(seed=offset_seed, mask=forward_mask)
-        softmax_dropout = nl.dropout(p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f],
-                                    rate=dropout_p_tensor[i_q_p, 0],
-                                    mask=forward_mask)
-        p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f] = \
-          nl.multiply(softmax_dropout, 1 / (1 - dropout_p), mask=forward_mask)
-
-    # Compute partial row-tile sum of exp(qk-max))
-    p_partial_sum[i_q_p, k_r_i] = nl.sum(p_local[i_q_p, k_r_i * REDUCTION_TILE + i_r_f], axis=1, dtype=acc_type, mask=forward_mask)
-
-  p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
-  for i_p_t in nl.affine_range(LARGE_TILE_SZ // 512):
-    p_local_t_tmp = nl.ndarray((par_dim(B_P_SIZE), 512), buffer=nl.psum, dtype=np.float32)
-    for i_p_t_local in nl.affine_range(512//128):
-      p_local_t_tmp[i_q_p, i_p_t_local*128 + i_f_128] = nisa.nc_transpose(p_local[i_q_p, i_p_t*512+i_p_t_local * B_P_SIZE + i_f_128])
-    i_f_512 = nl.arange(512)[None, :]
-    p_local_transposed[i_q_p, i_p_t * 512 + i_f_512 ] = nl.copy(p_local_t_tmp[i_q_p, i_f_512], dtype=kernel_dtype)
-
-  ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)
-  pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32, buffer=nl.psum)
-  for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
-    pv_psum[i_q_p, i_d_f] += nl.matmul(p_local_transposed[i_q_p, k_i * B_P_SIZE + i_f_128],
-                                       v[k_i, i_q_p, i_d_f],
-                                       transpose_x=True,
-                                       mask=forward_mask) # (128, 128) (p(Br), d)
-
-  if initialize:
-    o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.copy(pv_psum[i_q_p, i_d_f])
-    l_buffer[olm_buffer_idx, i_q_p, 0] = nl.add(nl.log(ps), max_)
-  else:
-    if use_causal_mask:
-      o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.copy(o_buffer[olm_buffer_idx-1, i_q_p, i_d_f], mask=negation_mask)
-    o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask)
-
-    l_prev = l_buffer[olm_buffer_idx-1, i_q_p, 0]
-    l_exp = nl.add(nl.exp(nl.subtract(l_prev, m_current, mask=forward_mask), mask=forward_mask), ps, mask=forward_mask)
-    l_buffer[olm_buffer_idx, i_q_p, 0] = nl.add(m_current, nl.log(l_exp, mask=forward_mask), mask=forward_mask)
-    if use_causal_mask:
-      l_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(l_buffer[olm_buffer_idx-1, i_q_p, 0], mask=negation_mask)
-
-
-def flash_fwd(q, k, v, seed, o, lse=None,
-              softmax_scale=None,
-              use_causal_mask=True,
-              mixed_precision=True,
-              dropout_p=0.0, config=None):
-  """
-  Flash Attention Forward kernel
-
-  IO tensor layouts:
-    - q: shape   (bs, n_heads, d, seq_q)
-    - k: shape   (bs, nk_heads, d, seq_k)
-    - v: shape   (bs, nv_heads, d, seq_v) if config.should_transpose_v  else (bs, nv_heads, seq_v, d)
-    - seed: shape (1,)
-    - o: shape (bs, n_heads, seq_q, d)
-    - lse: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None
-    - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k
-
-  IO tensor dtypes:
-    - This kernel assumes all IO tensors have the same dtype
-    - If mixed_percision is True, then all Tensor Engine operation will be performed in
-      bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
-      will be in the same type as the inputs.
-
-  Compile-time Constants:
-    - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
-    - mixed_precision: flag to set non-matmul ops in fp32 precision, defualt is set to `true`, if false, we use same precision as input types
-    - causal_mask: flag to set causal masking
-    - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig` with Performance config parameters for flash attention with default values
-        seq_tile_size: `default=2048`, size of the kv tile size for attention computation reduction
-        training: bool to indicate training vs inference `default=True`
-
-  Performance Notes:
-    For better performance, the kernel is tiled to be of size `LARGE_TILE_SZ`, and Flash attention math techniques are applied in unit
-    of `LARGE_TILE_SZ`. Seqlen that is not divisible by `LARGE_TILE_SZ` is not supported at the moment.
-
-  GQA support Notes:
-    the spmd kernel for launching kernel should be on kv_heads instead of nheads
-
-  Example usage:
-    MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
-      usage: `flash_fwd[b, h](q, k, v, ...)`
-    GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
-      usage: `flash_fwd[b, kv_h](q, k, v, ...)`
-  """
-  config = config or FlashConfig()
-  B_F_SIZE=512
-  B_P_SIZE=128
-  b , h, d, n  = q.shape
-  B_D_SIZE = d
-  k_h = k.shape[1]
-  v_shape = v.shape
-  if config.should_transpose_v:
-    assert tuple(v_shape) == (b, k_h, d, n), f"V shape does not match layout requirements, expect: {(b, k_h, d, n)} but got {v_shape}"
-    assert tuple(k.shape) == (b, k_h, d, n), f" k and v shape does not match the layout defined in the function, but got {k.shape}"
-  else:
-    assert tuple(v_shape) == (b, k_h, n, d), f"V shape does not match layout requirements, expect: {(b, k_h, n, d)} but got {v_shape}"
-    assert tuple(k.shape) == (b,k_h, d, n), f" k and v shape does not match the layout defined in the function, but got {k.shape}"
-  assert d <= 128, f" we do not support head_dim > 128, got head dim {d}"
-  kernel_dtype = nl.bfloat16 if mixed_precision else q.dtype
-  acc_type =  np.dtype(np.float32) if mixed_precision else kernel_dtype
-
-  i_q_p = nl.arange(B_P_SIZE)[:,None]
-  i_0_f = nl.arange(1)[None, :]
-  n_tile_q = n//B_P_SIZE # since q will be loaded on PE
-
-  batch_id = nl.program_id(axis=0)
-  head_id = nl.program_id(axis=1)
-  softmax_scale = softmax_scale or (1.0 / (d ** 0.5))
-
-  LARGE_TILE_SZ = config.seq_tile_size
-  # FIXME: Add masking for different seqlen values.
-  assert config.seq_tile_size >= 512, f" seq tile_size {config.seq_tile_size} cannot be less than 512"
-  assert n % LARGE_TILE_SZ == 0, f"seqlen is not divisible by {LARGE_TILE_SZ}"
-  num_large_k_tile = n // LARGE_TILE_SZ
-
-  # inference flag, check if lse is none
-  inference = not(config.training)
-  if inference:
-    assert lse is None, "lse should be none for inference"
-    assert seed is None, f"seed should be None for inference, but got {seed}"
-    assert dropout_p==0.0, f"dropout should be 0.0 for inference but got {dropout_p}"
-  else:
-    assert lse is not None, "lse should not be none for training"
-  q_h_per_k_h = h // k_h
-
-  if dropout_p > 0.0 and not inference:
-    seed_local = nl.load(seed[0])
-    # TODO: Remove this once the dropout supports scale prob
-    dropout_p_tensor = nl.full((B_P_SIZE, 1), fill_value=dropout_p, dtype=np.float32)
-  else:
-    dropout_p_tensor = None
-    seed_local = None
-
-  for i_q_h in nl.affine_range(q_h_per_k_h):
-
-    # =============== Global Flash Attention accumulators ====================== #
-    o_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), d), 0.0, dtype=acc_type, buffer=nl.sbuf)
-    l_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), 1), 0.0, dtype=acc_type, buffer=nl.sbuf)
-    m_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), 1), 0.0, dtype=acc_type)
-    # =============== Global Flash Attention accumulators END ================== #
-
-    j = 0
-    cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
-    cur_v_tile = nl.ndarray((LARGE_TILE_SZ//B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype)
-    load_tile_size = B_P_SIZE
-    for k_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-      load_p = nl.arange(B_D_SIZE)[:, None]
-      load_f = nl.arange(load_tile_size)[None, :]
-      cur_k_tile[load_p, load_tile_size*k_i+load_f] = nl.load(
-        k[batch_id, head_id, load_p, load_tile_size*k_i+load_f]
-      )
-    if config.should_transpose_v:
-      for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-        load_p = nl.arange(B_D_SIZE)[:, None]
-        load_f = nl.arange(B_P_SIZE)[None, :]
-
-        loaded = nl.load(v[batch_id, head_id, load_p, B_P_SIZE*v_i+load_f], dtype=kernel_dtype)
-        store_p = nl.arange(B_P_SIZE)[:, None]
-        store_f = nl.arange(B_D_SIZE)[None, :]
-        cur_v_tile[v_i, store_p, store_f] = nisa.nc_transpose(loaded)
-    else:
-      for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-        load_p = nl.arange(B_P_SIZE)[:, None]
-        load_f = nl.arange(B_D_SIZE)[None, :]
-
-        cur_v_tile[v_i, load_p, load_f] = nl.load(v[batch_id, head_id, B_P_SIZE*v_i+load_p, load_f], dtype=kernel_dtype)
-
-    for i in nl.affine_range(n_tile_q):
-      i_f_128 = nl.arange(B_P_SIZE)[None, :]
-      i_f_d = nl.arange(B_D_SIZE)[None, :]
-      i_p_d = nl.arange(B_D_SIZE)[:,None]
-      q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype)
-      q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, head_id * q_h_per_k_h + i_q_h, i_p_d, i*B_P_SIZE+i_f_128], dtype=kernel_dtype) \
-                                * softmax_scale # load (d, 128) tile in SBUF
-      # handle first tile and compute max and lse explicitly by passing initialize=True
-      _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile,
-                            q_h_per_k_h=q_h_per_k_h,
-                            o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i],
-                            batch_id=batch_id, head_id=head_id,
-                            gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=0,
-                            kernel_dtype=kernel_dtype, acc_type=acc_type,
-                            flash_config=config, use_causal_mask=use_causal_mask,
-                            initialize=True,
-                            B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE,
-                            dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_local)
-
-    for j in nl.sequential_range(1, num_large_k_tile):
-      cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype)
-      cur_v_tile = nl.ndarray((LARGE_TILE_SZ//B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype)
-      load_tile_size = B_P_SIZE
-      for k_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-        load_p = nl.arange(B_D_SIZE)[:, None]
-        load_f = nl.arange(load_tile_size)[None, :]
-        cur_k_tile[load_p, load_tile_size*k_i+load_f] = nl.load(
-          k[batch_id, head_id, load_p, j*LARGE_TILE_SZ+load_tile_size*k_i+load_f]
-        )
-      if config.should_transpose_v:
-        for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-          load_p = nl.arange(B_D_SIZE)[:, None]
-          load_f = nl.arange(B_P_SIZE)[None, :]
-
-          loaded = nl.load(v[batch_id, head_id, load_p, j*LARGE_TILE_SZ+B_P_SIZE*v_i+load_f], dtype=kernel_dtype)
-          store_p = nl.arange(B_P_SIZE)[:, None]
-          store_f = nl.arange(B_D_SIZE)[None, :]
-          cur_v_tile[v_i, store_p, store_f] = nisa.nc_transpose(loaded)
-      else:
-        for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
-          load_p = nl.arange(B_P_SIZE)[:, None]
-          load_f = nl.arange(B_D_SIZE)[None, :]
-
-          cur_v_tile[v_i, load_p, load_f] = nl.load(v[batch_id, head_id, j*LARGE_TILE_SZ+B_P_SIZE*v_i+load_p, load_f], dtype=kernel_dtype)
-
-      for i in nl.affine_range(n_tile_q):
-        i_f_128 = nl.arange(B_P_SIZE)[None, :]
-        i_f_d = nl.arange(B_D_SIZE)[None, :]
-        i_p_d = nl.arange(B_D_SIZE)[:,None]
-        q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype)
-        q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, head_id * q_h_per_k_h + i_q_h, i_p_d, i*B_P_SIZE+i_f_128], dtype=kernel_dtype) \
-                                  * softmax_scale # load (d, 128) tile in SBUF
-        _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile,
-                              q_h_per_k_h=q_h_per_k_h,
-                              o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i],
-                              batch_id=batch_id, head_id=head_id,
-                              gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j,
-                              kernel_dtype=kernel_dtype, acc_type=acc_type,
-                              flash_config=config, use_causal_mask=use_causal_mask,
-                              initialize=False,
-                              B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE,
-                              dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_local)
-
-    # -------- write output to buffer on HBM ------------ #
-    for i in nl.affine_range(n_tile_q):
-      out = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype)
-      out[i_q_p, i_f_d] = nl.multiply(o_buffer[i, num_large_k_tile - 1, i_q_p, i_f_d],
-                                      nl.exp(m_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f] - l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f]),
-                                      dtype=kernel_dtype)
-
-      nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, i * B_P_SIZE + i_q_p, i_f_d], out[i_q_p, i_f_d])
-      if not inference:
-        lse_local = nl.zeros((par_dim(B_P_SIZE), 1), dtype=acc_type)
-        lse_local[i_q_p, i_0_f] = nl.copy(l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f], dtype=acc_type)
-        nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, i_q_p, i + i_0_f], lse_local[i_q_p, i_0_f])
-
-
-def flash_attn_bwd(
-  q_ref, k_ref, v_ref, o_ref,
-  dy_ref,
-  lse_ref,
-  seed_ref,
-  out_dq_ref, out_dk_ref, out_dv_ref,
-  use_causal_mask=False,
-  mixed_precision=False,
-  dropout_p=0.0,
-  softmax_scale=None,
-):
-  """
-  Flash attention backward kernel. Compute the backward gradients.
-
-  IO tensor layouts:
-   - q_ref: shape (bs, nheads, head_size, seq)
-   - k_ref: shape (bs, nheads, head_size, seq)
-   - v_ref: shape (bs, nheads, head_size, seq)
-   - o_ref: shape (bs, nheads, head_size, seq)
-   - dy_ref: shape (bs, nheads, head_size, seq)
-   - lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax)
-   - seed_ref: shape (1,)
-   - out_dq_ref: shape (bs, nheads, head_size, seq)
-   - out_dk_ref: shape (bs, nheads, head_size, seq)
-   - out_dv_ref: shape (bs, nheads, head_size, seq)
-
-  Detailed steps:
-    1. D = rowsum(dO ◦ O) (pointwise multiply)
-
-    2. Recompute (softmax(Q^T@K))
-
-      2.1 Q^T@K
-      2.2 Scale the QK score
-      2.3 Apply causal mask
-      2.4 softmax
-
-    3. Compute the gradients of y = score @ V with respect to the loss
-
-    4. Compute the gradients of y = softmax(x)
-
-    5. Compute the gradients of Q^T@K
-
-      4.1 Compute dQ
-      4.2 Compute dK
-  """
-
-  # Use q_ref dtype as the intermediate tensor dtype
-  # Assume all IO tensors have the same dtype
-  kernel_dtype = q_ref.dtype
-  mixed_dtype = np.dtype(np.float32) if mixed_precision else kernel_dtype
-
-  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype \
-         == out_dq_ref.dtype == out_dk_ref.dtype == out_dv_ref.dtype
-  assert lse_ref.dtype == mixed_dtype
-
-  # Shape checking
-  bs, nheads, d_head, seqlen = q_ref.shape
-  assert tuple(k_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Input K shape mismatch, got {k_ref.shape}"
-  assert tuple(v_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Input V shape mismatch, got {v_ref.shape}"
-  assert tuple(o_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Input o shape mismatch, got {o_ref.shape}"
-  assert tuple(dy_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Input dy shape mismatch, got {dy_ref.shape}"
-  assert tuple(lse_ref.shape) == (bs, nheads, nl.tile_size.pmax, seqlen // nl.tile_size.pmax), \
-    f"Input lse shape mismatch, got {lse_ref.shape}"
-  if seed_ref is not None:
-    assert tuple(seed_ref.shape) == (1,), \
-      f"Input seed shape mismatch, got {seed_ref.shape}"
-
-  assert tuple(out_dq_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Output dQ shape mismatch, got {out_dq_ref.shape}"
-  assert tuple(out_dk_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Output dK shape mismatch, got {out_dk_ref.shape}"
-  assert tuple(out_dv_ref.shape) == (bs, nheads, d_head, seqlen), \
-    f"Output dV shape mismatch, got {out_dv_ref.shape}"
-
-  # FIXME: Add masking for different seqlen values.
-  assert seqlen % 128 == 0, \
-    f"Input sequence length must be divisible by 128, got {seqlen}"
-
-  # Softmax scaling factor, multiplied onto Q
-  softmax_scale = softmax_scale or 1.0 / float(d_head ** 0.5)
-
-  # Different batch samples/attention heads have independent attention
-  batch_id = nl.program_id(axis=0)
-  head_id = nl.program_id(axis=1)
-
-  assert nl.num_programs(1) == nheads, \
-    f"The grid shape mismatch, got {nl.num_programs(1)} but should be {nheads}"
-
-  q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen, 128), 128
-  d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128)
-
-  if seqlen >= 512:
-    k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512
-  else:
-    k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128
-
-  k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen // 128, 128
-  k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward
-
-  ##############################################################
-  # Step 2.4 Prefetch exp bias for softmax
-  ##############################################################
-  softmax_exp_bias = nl.zeros((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype)
-  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
-    ip_qk = nl.arange(q_seq_tile_size)[:, None]
-    lse_local = nl.load(
-      lse_ref[batch_id, head_id, ip_qk, i_q_seq_tile],
-      dtype=mixed_dtype)
-    softmax_exp_bias[i_q_seq_tile, ip_qk, 0] = lse_local * -1.0
-
-  ##############################################################
-  # Step 1 Compute rowsum(dO ◦ O)
-  ##############################################################
-  dy_o_sum = nl.ndarray((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype)
-  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
-    ip_reduce = nl.arange(q_seq_tile_size)[:, None]
-    dy_o_partial = nl.zeros((par_dim(q_seq_tile_size), d_head_n_tiles), dtype=mixed_dtype)
-    for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-      ip_load = nl.arange(d_head_tile_size)[:, None]
-      if_q = nl.arange(q_seq_tile_size)[None, :]
-      dy_local = nl.load_transpose2d(
-        dy_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_load, i_q_seq_tile * q_seq_tile_size + if_q],
-        dtype=mixed_dtype)
-      o_local = nl.load_transpose2d(
-        o_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_load, i_q_seq_tile * q_seq_tile_size + if_q],
-        dtype=mixed_dtype
-      )
-
-      dy_o_partial[ip_reduce, i_d_head_tile] = nisa.tensor_reduce(
-        np.add, data=dy_local*o_local, axis=(1,), dtype=mixed_dtype
-      )
-
-    dy_o_sum[i_q_seq_tile, ip_reduce, 0] = nisa.tensor_reduce(
-      np.add, data=dy_o_partial[ip_reduce, nl.arange(d_head_n_tiles)[None, :]],
-      axis=(1,), dtype=mixed_dtype
-    )
-
-  # Indices for prefetch
-  ip_qk = nl.arange(d_head_tile_size)[:, None]
-  if_q = nl.arange(q_seq_tile_size)[None, :]
-  if_k = nl.arange(k_seq_tile_size)[None, :]
-
-  if dropout_p > 0.0:
-    seed_local = nl.load(seed_ref[0])
-    # TODO: Remove this once the dropout supports scale prob
-    dropout_p_local = nl.full((q_seq_tile_size, 1), fill_value=dropout_p, dtype=np.float32)
-  else:
-    seed_local = None
-    dropout_p_local = None
-
-  dq_local_reduced = nl.zeros((q_seq_n_tiles, d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size),
-                              dtype=mixed_dtype)
-
-  # affine_range give the compiler permission to vectorize instructions
-  # inside the loop which improves the performance. However, when using the
-  # the dropout we should use sequential_range to avoid setting
-  # seed vectorization. TODO: the compiler should avoid vectorizing seed setting
-  _range = nl.sequential_range if dropout_p > 0.0 else nl.affine_range
-
-  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
-    # Prefetch V, K
-    v_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=kernel_dtype)
-    k_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=kernel_dtype)
-    transposed_k_local = nl.zeros((k_seq_fwd_bwd_tile_multipler, d_head_n_tiles, par_dim(k_seq_tile_size_backward), d_head_tile_size), dtype=kernel_dtype)
-    for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-      k_local[i_d_head_tile, ip_qk, if_k] = nl.load(
-        k_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_k_seq_tile * k_seq_tile_size + if_k],
-        dtype=kernel_dtype)
-      v_local[i_d_head_tile, ip_qk, if_k] = nl.load(
-        v_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_k_seq_tile * k_seq_tile_size + if_k],
-        dtype=kernel_dtype)
-      ##############################################################
-      # Prefetch k transpose for the backward too
-      ##############################################################
-      if_k_backward = nl.arange(k_seq_tile_size_backward)[None, :]
-      ip_k_backward = nl.arange(k_seq_tile_size_backward)[:, None]
-      if_d_head = nl.arange(d_head_tile_size)[None, :]
-      for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler):
-        transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, ip_k_backward, if_d_head] = \
-          nisa.nc_transpose(k_local[i_d_head_tile, ip_qk,
-                                    i_k_seq_tile_backward * k_seq_tile_size_backward + if_k_backward])
-
-    dv_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
-                        dtype=np.float32, buffer=nl.psum)
-    dk_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size),
-                        dtype=np.float32, buffer=nl.psum)
-    for i_q_seq_tile in _range(q_seq_n_tiles):
-      # Prefetch dy, Q
-      dy_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype)
-      q_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype)
-      for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-        ip_qk = nl.arange(d_head_tile_size)[:, None]
-        if_q = nl.arange(q_seq_tile_size)[None, :]
-
-        dy_local[i_d_head_tile, ip_qk, if_q] = nl.load(
-          dy_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_q_seq_tile * q_seq_tile_size + if_q],
-          dtype=kernel_dtype)
-
-        q_local[i_d_head_tile, ip_qk, if_q] = nl.load(
-          q_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_q_seq_tile * q_seq_tile_size + if_q],
-          dtype=kernel_dtype) * softmax_scale
-
-      _flash_attn_bwd_core(
-        q_local=q_local, k_local=k_local, transposed_k_local=transposed_k_local,
-        v_local=v_local, dy_local=dy_local,
-        dk_psum=dk_psum, dv_psum=dv_psum, dq_local_reduced=dq_local_reduced,
-        softmax_exp_bias=softmax_exp_bias, dy_o_sum=dy_o_sum,
-        local_i_q_seq_tile=i_q_seq_tile, local_i_k_seq_tile=i_k_seq_tile,
-        seqlen=seqlen, d_head=d_head,
-        use_causal_mask=use_causal_mask,
-        kernel_dtype=kernel_dtype, mixed_dtype=mixed_dtype,
-        softmax_scale=softmax_scale,
-        seed_local=seed_local, dropout_p=dropout_p, dropout_p_local=dropout_p_local,
-      )
-
-    # Write dK, dV
-    for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-      ip_dkv = nl.arange(d_head_tile_size)[:, None]
-      if_dkv = nl.arange(k_seq_tile_size)[None, :]
-
-      nl.store(
-        out_dv_ref[batch_id, head_id,
-                   i_d_head_tile * d_head_tile_size + ip_dkv,
-                   i_k_seq_tile * k_seq_tile_size + if_dkv],
-        value=dv_psum[i_d_head_tile, ip_dkv, if_dkv],
-      )
-
-      nl.store(
-        out_dk_ref[batch_id, head_id,
-                    i_d_head_tile * d_head_tile_size + ip_dkv,
-                    i_k_seq_tile * k_seq_tile_size + if_dkv],
-        value=dk_psum[i_d_head_tile, ip_dkv, if_dkv],
-      )
-
-  # Write dQ
-  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
-    for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-      ip_dq = nl.arange(d_head_tile_size)[:, None]
-      if_dq = nl.arange(q_seq_tile_size)[None, :]
-
-      nl.store(
-        out_dq_ref[batch_id, head_id,
-                   i_d_head_tile * d_head_tile_size + ip_dq,
-                   i_q_seq_tile * q_seq_tile_size + if_dq],
-        value=dq_local_reduced[i_q_seq_tile, i_d_head_tile, ip_dq, if_dq],
-      )
-
-@trace
-def _flash_attn_bwd_core(
-  q_local, k_local, transposed_k_local, v_local, dy_local,
-  dk_psum, dv_psum, dq_local_reduced,
-  softmax_exp_bias, dy_o_sum,
-  local_i_q_seq_tile, local_i_k_seq_tile,
-  seqlen, d_head,
-  use_causal_mask,
-  kernel_dtype, mixed_dtype,
-  softmax_scale,
-  seed_local, dropout_p, dropout_p_local,
-  global_i_q_seq_tile = None,
-  global_i_k_seq_tile = None,
-):
-  """
-  The flash attention backward core function that calculates the gradients of Q, K and V
-  for the given tiles. The results are accumulated into the dk, dv and dq PSUM buffers.
-  """
-  q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen, 128), 128
-  d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128)
-  if seqlen >= 512:
-    k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512
-  else:
-    k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128
-  k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen // 128, 128
-  k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward
-
-  if global_i_q_seq_tile is None:
-    global_i_q_seq_tile = local_i_q_seq_tile
-    global_i_k_seq_tile = local_i_k_seq_tile
-
-  mask = global_i_q_seq_tile * q_seq_tile_size >= global_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None
-  # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F]
-  qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
-                      dtype=np.float32, buffer=nl.psum)
-  qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), buffer=nl.sbuf, dtype=kernel_dtype)
-
-  batch_id = nl.program_id(axis=0)
-  head_id = nl.program_id(axis=1)
-  # Tensor indices for accessing qk result in k_seq_tile_size
-  if_q = nl.arange(q_seq_tile_size)[None, :]
-  ip_qk = nl.arange(d_head_tile_size)[:, None]
-
-  ip_q = nl.arange(q_seq_tile_size)[:, None]
-  if_k = nl.arange(k_seq_tile_size)[None, :]
-
-  # Loop over contraction dim of QK matmul
-  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-    ##############################################################
-    # Step 2.1 Compute Q^T@K, with matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
-    ##############################################################
-    qk_psum[ip_q, if_k] += nisa.nc_matmul(q_local[i_d_head_tile, ip_qk, if_q],
-                                            k_local[i_d_head_tile, ip_qk, if_k],
-                                            mask=mask)
-
-  ######################################
-  # Step 2.2. Apply optional causal mask
-  ######################################
-  if use_causal_mask:
-    # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
-    qk_res_buf[ip_q, if_k] = nisa.affine_select(
-      pred=(global_i_q_seq_tile * q_seq_tile_size + ip_q >= global_i_k_seq_tile * k_seq_tile_size + if_k),
-      on_true_tile=qk_psum[ip_q, if_k], on_false_value=-9984.0, dtype=mixed_dtype,
-      mask=mask)
-  else:
-    # Simply send psum result back to sbuf
-    qk_res_buf[ip_q, if_k] = \
-      nl.copy(qk_psum[ip_q, if_k], dtype=mixed_dtype)
-
-  softmax_y = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
-  softmax_y[ip_q, if_k] = nisa.activation(np.exp,
-                                            data=qk_res_buf[ip_q, if_k],
-                                            bias=softmax_exp_bias[local_i_q_seq_tile, ip_q, 0],
-                                            scale=1.0,
-                                            mask=mask)
-  #####################################################################
-  # Dropout
-  #####################################################################
-  if dropout_p > 0.0:
-    offset = global_i_k_seq_tile + global_i_q_seq_tile * k_seq_n_tiles \
-              + head_id * k_seq_n_tiles * q_seq_n_tiles \
-              + batch_id * nl.num_programs(1) * k_seq_n_tiles * q_seq_n_tiles
-    offset_seed = nl.add(seed_local[0, 0], offset, mask=mask)
-    nl.random_seed(seed=offset_seed, mask=mask)
-    softmax_y[ip_q, if_k] = nl.dropout(softmax_y[ip_q, if_k], rate=dropout_p_local[ip_q, 0], mask=mask)
-    softmax_y[ip_q, if_k] = nl.multiply(softmax_y[ip_q, if_k], 1 / (1 - dropout_p), mask=mask)
-
-  #####################################################################
-  # Step 3.1 Calculate the backward gradients dL/dV, where y=softmax@V
-  # in value projection with matmul(stationary=dy, moving=softmax)
-  #####################################################################
-  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-    ip_dv = nl.arange(d_head_tile_size)[:, None]
-    if_dv = nl.arange(k_seq_tile_size)[None, :]
-    if_trans_dy = nl.arange(q_seq_tile_size)[None, :]
-    trans_dy = nisa.nc_transpose(dy_local[i_d_head_tile, ip_dv, if_trans_dy],
-                                  mask=mask)
-    dv_psum[i_d_head_tile, ip_dv, if_dv] += \
-      nisa.nc_matmul(trans_dy, softmax_y[ip_q, if_k], mask=mask)
-
-  #####################################################################
-  # Step 3.2 Calculate the backward gradients dL/dsoftmax, where y=softmax@V
-  # in value projection with matmul(stationary=dy, moving=v)
-  #####################################################################
-  softmax_dy_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
-                              dtype=np.float32, buffer=nl.psum)
-  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-    ip_softmax_dy = nl.arange(d_head_tile_size)[:, None]
-    if_dy = nl.arange(q_seq_tile_size)[None, :]
-    softmax_dy_psum[ip_q, if_k] += \
-      nisa.nc_matmul(dy_local[i_d_head_tile, ip_softmax_dy, if_dy],
-                      v_local[i_d_head_tile, ip_softmax_dy, if_k],
-                      mask=mask)
-
-  softmax_dy = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
-  softmax_dy[ip_q, if_k] = nl.copy(softmax_dy_psum[ip_q, if_k], dtype=kernel_dtype,
-                                      mask=mask)
-
-  #####################################################################
-  # Step 4 Calculate the softmax backward gradients dL/dx, where y=softmax(x)
-  # dL/dx = y * (dL/dy - rowsum(dO_O)), where y = softmax(x)
-  #####################################################################
-  softmax_dx_local = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf)
-  softmax_dx_local[ip_q, if_k] = \
-    nisa.tensor_scalar(data=softmax_dy[ip_q, if_k],
-                        op0=np.subtract,
-                        operand0=dy_o_sum[local_i_q_seq_tile, ip_q, 0],
-                        op1=np.multiply,
-                        operand1=softmax_y[ip_q, if_k],
-                        mask=mask)
-
-  #####################################################################
-  # Step 5.1 Calculate dK, with matmul(stationary=Q, moving=softmax_dx)
-  #####################################################################
-  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-    ip_trans_q = nl.arange(d_head_tile_size)[:, None]
-    if_trans_q = nl.arange(q_seq_tile_size)[None, :]
-    ip_dk = nl.arange(d_head_tile_size)[:, None]
-    trans_q_local = nisa.nc_transpose(q_local[i_d_head_tile, ip_trans_q, if_trans_q],
-                                      mask=mask)
-    dk_psum[i_d_head_tile, ip_dk, if_k] += \
-      nisa.nc_matmul(trans_q_local,
-                      softmax_dx_local[ip_q, if_k],
-                      mask=mask)
-
-  #####################################################################
-  # Step 5.2 Calculate dQ
-  #####################################################################
-  if_k = nl.arange(k_seq_tile_size_backward)[None, :]
-  ip_dq = nl.arange(d_head_tile_size)[:, None]
-  if_dq = nl.arange(q_seq_tile_size)[None, :]
-  if_d = nl.arange(d_head_tile_size)[None, :]
-  ip_transposed_k = nl.arange(k_seq_tile_size_backward)[:, None]
-  for i_d_head_tile in nl.affine_range(d_head_n_tiles):
-    dq_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size),
-                        dtype=np.float32, buffer=nl.psum)
-    for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler):
-      transposed_softmax_dx_local = \
-        nisa.nc_transpose(softmax_dx_local[ip_q, i_k_seq_tile_backward * k_seq_tile_size_backward + if_k],
-                          mask=mask)
-      dq_psum[ip_dq, if_dq] += nisa.nc_matmul(
-          transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, ip_transposed_k, if_d],
-          transposed_softmax_dx_local,
-          mask=mask)
-    dq_local = nl.multiply(dq_psum[ip_dq, if_dq], softmax_scale, dtype=kernel_dtype, mask=mask)
-    dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, ip_dq, if_dq] = nl.loop_reduce(
-      dq_local, op=np.add, loop_indices=(local_i_k_seq_tile,),
-      dtype=mixed_dtype, mask=mask)
-
-def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False,
-                                           mixed_percision=True):
-  """
-  Fused self attention kernel for small head size Stable Diffusion workload.
-
-  Computes softmax(QK^T)V. A decoder model can optionally apply a causal mask.
-  Does not include QKV projection, output projection, dropout,
-  residual connection, etc.
-
-  This kernel is designed to be used for Stable Diffusion models where
-  n_heads is smaller than or equal to 128. An assertion is raised if `n_heads`
-  does not satisfy the requirement.
-
-  IO tensor layouts:
-   - q_ptr: shape   (bs, n_heads, seq_q)
-   - k_ptr: shape   (bs, seq_k, n_heads)
-   - v_ptr: shape   (bs, seq_v, n_heads)
-   - out_ptr: shape (bs, seq_q, n_heads)
-   - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k
-
-  IO tensor dtypes:
-   - This kernel assumes all IO tensors have the same dtype
-   - If mixed_percision is True, then all Tensor Engine operations will be performed in
-     bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
-     will have the same dtype as the inputs.
-  """
-  # Use q_ref dtype as the intermediate tensor dtype
-  # Assume all IO tensors have the same dtype
-  kernel_dtype = q_ref.dtype
-  pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
-  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype
-
-  # Shape checking
-  bs, d_head, seqlen = q_ref.shape
-  assert d_head <= 128, "Cannot use this kernel for d_head > 128"
-  assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!'
-  assert tuple(k_ref.shape) == (bs, seqlen, d_head), 'Input shape mismatch!'
-  assert tuple(v_ref.shape) == (bs, seqlen,
-                                d_head), f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
-  assert tuple(out_ref.shape) == (bs, seqlen, d_head), 'Output shape mismatch!'
-
-  # Softmax scaling factor, multiplied onto Q
-  softmax_scale = 0.125
-
-  # Different batch samples/attention heads have independent attention
-  batch_id = nl.program_id(axis=0)
-  # batch_id = 0
-
-  # TODO: make q_seq_tile_size user input
-  # The matmuls currently use a fixed tile size of (128, 128). This may not achieve the best
-  # performance for dense attention. However, since this kernel is in preparation
-  # for block-sparse attention, this tile size is acceptable because the block
-  # size of block-sparse attention cannot be too large.
-  q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128
-  k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128
-  # No tiling on d_head dimension since the number of d_head fits in SB
-  d_head_tile_size = d_head
-  v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128
-
-  ###################################
-  # Step 1. transpose(tensor_v)
-  ###################################
-  # Buffer for v matrix transposed
-  # Pre-fetch and keep it in SBUF throughout different softmax tiles
-  trans_v = nl.ndarray((par_dim(v_seq_tile_size), v_seq_n_tiles, d_head), dtype=pe_in_dt)
-
-  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
-    ip_v = nl.arange(v_seq_tile_size)[:, None]
-    if_v = nl.arange(d_head_tile_size)[None, :]
-    trans_v[ip_v, i_k_seq_tile, if_v] = nl.load(
-      v_ref[batch_id, i_k_seq_tile * k_seq_tile_size + ip_v, if_v],
-      dtype=pe_in_dt)
-
-  q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt)
-  ip_q = nl.arange(d_head_tile_size)[:, None]
-  if_q = nl.arange(q_seq_tile_size)[None, :]
-  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
-    q_local[i_q_seq_tile, ip_q, if_q] = nl.load(
-      q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q],
-      dtype=pe_in_dt) * softmax_scale
-
-  k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt)
-  ip_k = nl.arange(d_head_tile_size)[:, None]
-  if_k = nl.arange(k_seq_tile_size)[None, :]
-  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
-    k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d(
-      k_ref[batch_id,
-            i_k_seq_tile * k_seq_tile_size + nl.arange(k_seq_tile_size)[:, None],
-            nl.arange(d_head_tile_size)[None, :]],
-      dtype=pe_in_dt)
-
-  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):  # indent = 2
-    # A SBUF buffer for an independent softmax tile
-    qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype)
-
-    neg_max_res = nl.ndarray((par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype)
-    ip_max = nl.arange(q_seq_tile_size)[:, None]
-    if_max = nl.arange(k_seq_n_tiles)[None, :]
-
-    # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
-    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):  # indent = 4
-
-      # Since the K^T tile is the RHS, the q_seq_len dimension will be P in the result
-      # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F]
-      qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
-                         dtype=np.float32, buffer=nl.psum)
-
-      # Tensor indices for accessing qk result in k_seq_tile_size
-      ip_qk = nl.arange(q_seq_tile_size)[:, None]
-      if_qk = nl.arange(k_seq_tile_size)[None, :]
-
-      ##############################################################
-      # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
-      ##############################################################
-      qk_psum[ip_qk, if_qk] += nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k],
-                                              stationary=q_local[i_q_seq_tile, ip_q, if_q])
-
-      ###################################
-      # Step 3. Apply optional causal mask
-      ###################################
-      if use_causal_mask:
-        # Magic number -9984.0 to replace -inf similar to what Tensorizer uses
-        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select(
-          pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk),
-          on_true_tile=qk_psum[ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype)
-      else:
-        # Simply send psum result back to sbuf
-        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nl.copy(qk_psum[ip_qk, if_qk],
-                                                                              dtype=kernel_dtype)
-
-      ###################################
-      # Step 4. Softmax
-      ###################################
-      # TODO: use TensorScalarCacheReduce to avoid an extra copy
-      # We want to break this reduction into tiles so that it can overlap with the previous matmul
-      neg_max_res[ip_max, i_k_seq_tile] = nisa.tensor_reduce(
-        np.max, data=qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk],
-        axis=(1,), dtype=kernel_dtype, negate=True)
-
-    neg_max_res_final = nisa.tensor_reduce(
-      np.min, data=neg_max_res[ip_max, if_max],
-      axis=(1,), dtype=kernel_dtype, negate=False)
-
-    ip_softmax = nl.arange(q_seq_tile_size)[:, None]
-    if_softmax = nl.arange(seqlen)[None, :]
-    ip_sum_res = nl.arange(q_seq_tile_size)[:, None]
-    if_sum_res = nl.arange(d_head_tile_size)[None, :]
-
-    softmax_res = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=pe_in_dt)
-    sum_divisor = nl.ndarray((par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype)
-
-    # Simply use a large tile of seq_len in size since this is a "blocking" instruction
-    # Assuming the compiler will merge exp and reduce_add into a single instruction on ACT
-    exp_res = nisa.activation(np.exp,
-                              data=qk_res_buf[ip_softmax, if_softmax],
-                              bias=neg_max_res_final, scale=1.0)
-
-    sum_res = nisa.tensor_reduce(np.add, data=exp_res, axis=(1,),
-                          dtype=kernel_dtype)
-    softmax_res[ip_softmax, if_softmax] = nl.copy(exp_res, dtype=pe_in_dt)
-
-    sum_reciprocal_broadcast = (1.0 / sum_res).broadcast_to((q_seq_tile_size, d_head_tile_size))
-    sum_divisor[ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast, dtype=kernel_dtype)
-
-    # Buffer for transposed softmax results (FP32 in PSUM)
-    trans_softmax_res = nl.ndarray(
-      (par_dim(k_seq_tile_size), k_seq_n_tiles, q_seq_tile_size),
-      dtype=pe_in_dt)
-
-    # Result psum buffer has the hidden dim as P
-    attn_res_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size),
-                             dtype=np.float32, buffer=nl.psum)
-
-    ip_scores_t = nl.arange(k_seq_tile_size)[:, None]
-    if_scores_t = nl.arange(q_seq_tile_size)[None, :]
-    # Loop over matmul_1 contraction
-    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
-      ###################################
-      # Step 5. transpose(softmax_res)
-      ###################################
-      ip_scores = nl.arange(q_seq_tile_size)[:, None]
-      if_scores = nl.arange(k_seq_tile_size)[None, :]
-
-      trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t] = nisa.nc_transpose(
-        softmax_res[ip_scores, i_k_seq_tile * k_seq_tile_size + if_scores])
-
-    ip_out = nl.arange(d_head_tile_size)[:, None]
-    if_out = nl.arange(q_seq_tile_size)[None, :]
-    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
-      ######################################################################
-      # Step 6. matmul_1(stationary=trans_v, moving=trans_softmax_res, contract=seqlen_v=seqlen_k)
-      ######################################################################
-      ip_v_t = nl.arange(k_seq_tile_size)[:, None]
-      if_v_t = nl.arange(d_head_tile_size)[None, :]
-      attn_res_psum[ip_out, if_out] += \
-        nisa.nc_matmul(moving=trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t],
-                       stationary=trans_v[ip_v_t, i_k_seq_tile, if_v_t])
-
-    attn_res_sbuf = nl.copy(attn_res_psum[ip_out, if_out], dtype=kernel_dtype)
-
-    attn_res_div = attn_res_sbuf * nisa.nc_transpose(sum_divisor[ip_sum_res, if_sum_res])
-
-    nl.store(
-      out_ref[batch_id, i_q_seq_tile * q_seq_tile_size + if_out, ip_out],
-      value=attn_res_div)
-    
\ No newline at end of file
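
The Step 4 comment in the removed `_flash_attn_bwd_core` above relies on the standard softmax backward identity, dL/dx = y * (dL/dy_softmax - rowsum(dy * o)) with y = softmax(x). A minimal NumPy check of that identity, purely illustrative and independent of the NKI code:

```python
# Illustrative NumPy check of the softmax backward identity used in Step 4:
# with y = softmax(x), o = y @ V and dp = dy @ V.T (the gradient w.r.t. y),
# dL/dx = y * (dp - rowsum(dy * o)), because rowsum(dy * o) == rowsum(y * dp).
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
V = rng.standard_normal((8, 5))
dy = rng.standard_normal((4, 5))

y = np.exp(x - x.max(axis=-1, keepdims=True))
y /= y.sum(axis=-1, keepdims=True)
o = y @ V
dp = dy @ V.T

dx_kernel_form = y * (dp - np.sum(dy * o, axis=-1, keepdims=True))
dx_reference = y * (dp - np.sum(y * dp, axis=-1, keepdims=True))
assert np.allclose(dx_kernel_form, dx_reference)
```
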
diff --git a/src/reference/tutorial.py b/src/reference/tutorial.py
deleted file mode 100644
index 4f3ebef..0000000
--- a/src/reference/tutorial.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""
-Copyright (c) 2023, Amazon.com. All Rights Reserved
-
-kernels - Builtin high performance NKI kernels used in tutorial
-
-"""
-
-import neuronxcc.nki.language as nl
-
-def add_kernel_nx8x128x512(a_ptr, b_ptr, c_ptr, n_elements):
-  ix = nl.arange(128)[:, None]
-  iy = nl.arange(512)[None, :]
-
-  tile_size = 128 * 512
-  block_size = 8 * tile_size
-
-  j = nl.program_id(axis=0)
-
-  for i in nl.affine_range(8):
-    offset = j * block_size + i * tile_size + 512 * ix + iy
-    mask = offset < n_elements
-    a_ptr = a_ptr.ptr + offset
-    b_ptr = b_ptr.ptr + offset
-    c_ptr = c_ptr.ptr + offset
-
-    a = nl.load(a_ptr, mask=mask)
-    b = nl.load(b_ptr, mask=mask)
-    c = a + b
-    nl.store(c_ptr, value=c, mask=mask)
\ No newline at end of file
diff --git a/src/tutorials/tensor_addition/tensor_addition_jax.py b/src/tutorials/tensor_addition/tensor_addition_jax.py
deleted file mode 100644
index 9655b84..0000000
--- a/src/tutorials/tensor_addition/tensor_addition_jax.py
+++ /dev/null
@@ -1,61 +0,0 @@
-"""
-Copyright (C) 2024, Amazon.com. All Rights Reserved
-
-JAX implementation for tensor addition NKI tutorial.
-
-"""
-import jax
-import jax.numpy as jnp
-from jax_neuronx import nki_call
-
-from tensor_addition_nki_kernels import nki_tensor_add_kernel_
-
-
-def nki_tensor_add(a_input, b_input):
-  """NKI kernel caller to compute element-wise addition of two input tensors
-
-  This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs
-
-  Args:
-      a_input: a first input tensor, of shape [N*128, M*512]
-      b_input: a second input tensor, of shape [N*128, M*512]
-
-  Returns:
-      a tensor of shape [N*128, M*512], the result of a_input + b_input
-  """
-
-  # The SPMD launch grid denotes the number of kernel instances.
-  # In this case, we use a 2D grid where the size of each invocation is 128x512
-  grid_x = a_input.shape[0] // 128
-  grid_y = a_input.shape[1] // 512
-
-  out_shape = jax.ShapeDtypeStruct((a_input.shape[0], a_input.shape[1]), dtype=a_input.dtype)
-
-  return nki_call(
-      nki_tensor_add_kernel_,
-      a_input,
-      b_input,
-      grid=(grid_x, grid_y),
-      out_shape=out_shape,
-    )
-
-
-if __name__ == "__main__":
-
-  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
-  a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16)
-  b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16)
-
-  output_nki = nki_tensor_add(a, b)
-  print(f"output_nki={output_nki}")
-
-  output_jax = a + b
-  print(f"output_jax={output_jax}")
-
-  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
-  if allclose:
-    print("NKI and JAX match")
-  else:
-    print("NKI and JAX differ")
-
-  assert allclose
diff --git a/src/tutorials/tensor_addition/tensor_addition_torch.py b/src/tutorials/tensor_addition/tensor_addition_torch.py
deleted file mode 100644
index 942e728..0000000
--- a/src/tutorials/tensor_addition/tensor_addition_torch.py
+++ /dev/null
@@ -1,57 +0,0 @@
-"""
-Copyright (C) 2024, Amazon.com. All Rights Reserved
-
-PyTorch implementation for tensor addition NKI tutorial.
-
-"""
-import torch
-from torch_xla.core import xla_model as xm
-from torch_neuronx import nki_jit
-
-from tensor_addition_nki_kernels import nki_tensor_add_kernel_
-
-
-def nki_tensor_add(a_input, b_input):
-  """NKI kernel caller to compute element-wise addition of two input tensors
-
-  This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs
-
-  Args:
-      a_input: a first input tensor, of shape [N*128, M*512]
-      b_input: a second input tensor, of shape [N*128, M*512]
-
-  Returns:
-      a tensor of shape [N*128, M*512], the result of a_input + b_input
-  """
-
-  # The SPMD launch grid denotes the number of kernel instances.
-  # In this case, we use a 2D grid where the size of each invocation is 128x512
-  grid_x = a_input.shape[0] // 128
-  grid_y = a_input.shape[1] // 512
-  c_output = torch.zeros(a_input.shape, dtype=a_input.dtype).to(device=device)
-
-  # Decorate the NKI kernel for PyTorch tracing
-  nki_tensor_add_kernel_torch = nki_jit(nki_tensor_add_kernel_)
-  nki_tensor_add_kernel_torch[grid_x, grid_y](a_input, b_input, c_output)
-
-  return c_output
-
-if __name__ == "__main__":
-  device = xm.xla_device()
-
-  a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
-  b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
-
-  output_nki = nki_tensor_add(a, b)
-  print(f"output_nki={output_nki}")
-
-  output_torch = a + b
-  print(f"output_torch={output_torch}")
-
-  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
-  if allclose:
-    print("NKI and Torch match")
-  else:
-    print("NKI and Torch differ")
-
-  assert allclose
diff --git a/test/integration/flash_attention/flash_attention_benchmark.py b/test/integration/flash_attention/flash_attention_benchmark.py
index 5aa2e40..918a14f 100644
--- a/test/integration/flash_attention/flash_attention_benchmark.py
+++ b/test/integration/flash_attention/flash_attention_benchmark.py
@@ -14,6 +14,8 @@
 
 from flash_attention import nki_flash_attn_func
 
+parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(parent_dir)
 from perf_utils.LatencyCollector import benchmark
 
 if len(sys.argv) != 2:
diff --git a/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py b/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py
index e7fd205..5d63424 100644
--- a/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py
+++ b/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py
@@ -8,6 +8,8 @@
 from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
 
+parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(parent_dir)
 from perf_utils.LatencyCollector import benchmark
 
 import sys
diff --git a/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py b/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py
index 3ba0eab..4970f72 100644
--- a/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py
+++ b/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py
@@ -23,6 +23,8 @@
 else:
     from diffusers.models.cross_attention import CrossAttention
 
+parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(parent_dir)
 from perf_utils.LatencyCollector import benchmark
 
 import sys
diff --git a/test/unit/README.md b/test/unit/README.md
index e0835de..55dc937 100644
--- a/test/unit/README.md
+++ b/test/unit/README.md
@@ -1 +1,7 @@
-Tests under this folder are unit tests for the kernels in `neuronxcc.nki.kernels`, and they are part of the nki-samples Github Repo. Only public APIs can be used for tests in this folder.
\ No newline at end of file
+Tests under this folder are unit tests for the kernels in `src/nki_samples`.
+
+To execute the tests, the `src` directory must be on `PYTHONPATH`.
+
+For example:
+
+PYTHONPATH=$PYTHONPATH:/home/ubuntu/nki-samples/src/ pytest test_flash_attn_fwd.py
\ No newline at end of file
diff --git a/test/unit/__main__.py b/test/unit/__main__.py
deleted file mode 100644
index 34fee3a..0000000
--- a/test/unit/__main__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import os
-import sys
-
-# This file is basically a hack around the fact that pytest has a bug where it does not discover conftest.py correctly if you launch the test using --pyargs.
-# https://github.com/pytest-dev/pytest/issues/1596
-
-
-# Todo: Using __file__ isn't the most robust. Figure out how to do this using importlib or similar.
-test_root = os.path.dirname(__file__)
-
-if __name__ == "__main__":
-  import pytest
-  errcode = pytest.main([test_root] + sys.argv[1:])
-  sys.exit(errcode)
\ No newline at end of file
diff --git a/test/unit/conftest.py b/test/unit/conftest.py
new file mode 100644
index 0000000..cd663ae
--- /dev/null
+++ b/test/unit/conftest.py
@@ -0,0 +1,28 @@
+import pytest
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--simulation-only", action="store_true", default=False, help="Run simulation only, it will run test with `simulation` marker in simulation mode"
+    )
+
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers", "simulation: mark simulation test that can be executed without a NeuronDevice"
+    )
+
+@pytest.fixture
+def simulation_only(request):
+    return request.config.getoption("--simulation-only")
+
+def pytest_collection_modifyitems(session, config, items):
+    if config.getoption("--simulation-only"):
+        # Only run cases carrying the `simulation` marker
+        result = []
+        for item in items:
+            for marker in item.iter_markers():
+                if marker.name == 'simulation':
+                    result.append(item)
+                    break
+        items.clear()
+        items.extend(result)
+        
\ No newline at end of file
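
The conftest.py added above registers the `simulation` marker and the `simulation_only` fixture. A sketch of how a test module is expected to use them; `tensor_add_kernel` here is an illustrative kernel, not part of this change:

```python
# Hypothetical test module showing the marker/fixture wiring from conftest.py.
import numpy as np
import pytest
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
from neuronxcc.nki import baremetal, simulate_kernel

@nki.jit
def tensor_add_kernel(a_tensor, b_tensor):
    # Minimal NKI kernel used only for illustration.
    c_output = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, buffer=nl.shared_hbm)
    nl.store(c_output, value=nl.load(a_tensor) + nl.load(b_tensor))
    return c_output

@pytest.mark.simulation            # collected when pytest runs with --simulation-only
def test_tensor_add_numeric(simulation_only):
    a = np.random.random_sample((128, 512)).astype(np.float32)
    b = np.random.random_sample((128, 512)).astype(np.float32)

    numeric_func = baremetal(tensor_add_kernel)
    if simulation_only:
        out = simulate_kernel(numeric_func, a, b)   # no NeuronDevice required
    else:
        out = numeric_func(a, b)                    # runs on a NeuronDevice
    assert np.allclose(out, a + b, atol=1e-6)
```
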
diff --git a/test/unit/test_SD_attention_small_head.py b/test/unit/test_SD_attention_small_head.py
index 5480fa4..1a54a4b 100644
--- a/test/unit/test_SD_attention_small_head.py
+++ b/test/unit/test_SD_attention_small_head.py
@@ -3,15 +3,14 @@
 """
 import os
 import pytest
-from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.attention import fused_self_attn_for_SD_small_head_size
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
 from scipy.special import softmax
 
 test_trace_file_path='local_trace.ntff'
-numeric_func = baremetal(fused_self_attn_for_SD_small_head_size)
-bench_func = benchmark(warmup=5, iters=10, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size)
+bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size)
 
 def cpu_golden_attn(q, k, v):
     softmax_scale = 0.125
@@ -34,33 +33,37 @@ def test_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency):
         q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
         k = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
         v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
-        out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype)
-        
+
         q_dev = nl.static_cast(q, dtype)
         k_dev = nl.static_cast(k, dtype)
         v_dev = nl.static_cast(v, dtype)
 
-        bench_func[bs](q_dev, k_dev, v_dev, out)
-        latency_res = bench_func.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
-        assert p99 <= latency
+        bench_func_ = bench_func[bs]
+        bench_func_(q_dev, k_dev, v_dev)
+        latency_res = bench_func_.benchmark_result.nc_latency
+        p50 = latency_res.get_latency_percentile(50)
+        assert p50 <= latency * 1.05 # short-running kernels are subject to hardware fluctuation
         assert os.path.getsize(test_trace_file_path) > 0
 
+    @pytest.mark.simulation
     @pytest.mark.parametrize("bs,seqlen,d,dtype", [
         [1, 4096, 128, np.float32],
         [1, 4096, 128, nl.bfloat16]
     ])
-    def test_attention_for_SD_numberic(self, bs, seqlen, d, dtype):
+    def test_attention_for_SD_numberic(self, simulation_only, bs, seqlen, d, dtype):
         q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
         k = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
         v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
-        out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype)
-        
+
         q_dev = nl.static_cast(q, dtype)
         k_dev = nl.static_cast(k, dtype)
         v_dev = nl.static_cast(v, dtype)
 
-        numeric_func[bs](q_dev, k_dev, v_dev, out)
+        numeric_func = baremetal(fused_self_attn_for_SD_small_head_size)
+        if simulation_only:
+            out = simulate_kernel(numeric_func[bs], q_dev, k_dev, v_dev)
+        else:
+            out = numeric_func[bs](q_dev, k_dev, v_dev)
         out = nl.static_cast(out, np.float32)
         golden_result = cpu_golden_attn(q, k, v)
         assert np.allclose(out, golden_result, atol=1e-2)
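
The same baremetal-or-simulate dispatch recurs in each numeric test updated below. A small helper of the shape sketched here could factor it out; the helper is illustrative and not part of this change:

```python
# Hypothetical helper capturing the repeated pattern in these tests: build the
# baremetal kernel, then run it through the simulator or on the device.
from neuronxcc.nki import baremetal, simulate_kernel

def run_numeric(kernel, grid, *args, simulation_only=False, **kwargs):
    numeric_func = baremetal(kernel)
    launched = numeric_func[grid] if grid is not None else numeric_func
    if simulation_only:
        return simulate_kernel(launched, *args, **kwargs)
    return launched(*args, **kwargs)
```
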
diff --git a/test/unit/test_allocated_SD_attention_small_head.py b/test/unit/test_allocated_SD_attention_small_head.py
new file mode 100644
index 0000000..712148f
--- /dev/null
+++ b/test/unit/test_allocated_SD_attention_small_head.py
@@ -0,0 +1,72 @@
+"""
+Copyright (c) 2023, Amazon.com. All Rights Reserved
+"""
+import os
+import pytest
+from nki_samples.reference.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
+import neuronxcc.nki as nki
+import neuronxcc.nki.language as nl
+import numpy as np
+from scipy.special import softmax
+
+test_trace_file_path='local_trace.ntff'
+
+bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(allocated_fused_self_attn_for_SD_small_head_size)
+
+def cpu_golden_attn(q, k, v):
+    softmax_scale = 0.125
+    q_scaled = q * softmax_scale
+    raw_score = np.matmul(q_scaled.transpose(0, 2, 1), k)
+    
+    norm_score = softmax(raw_score, axis=-1)
+
+    # Transpose the result so it has the same layout as ours
+    return np.matmul(norm_score, v).transpose(0, 2, 1)
+
+class TestAttention:
+
+    @pytest.mark.parametrize("bs,seqlen,d,dtype,latency", [
+        [1, 4096, 128, np.float32, 410],
+        [1, 4096, 128, nl.bfloat16, 350],
+        [1, 5120, 128, nl.bfloat16, 586]
+    ])
+    def test_allocated_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency):
+        q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
+        k = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
+        v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
+
+        q_dev = nl.static_cast(q, dtype)
+        k_dev = nl.static_cast(k, dtype)
+        v_dev = nl.static_cast(v, dtype)
+
+        bench_func_ = bench_func[bs]
+        bench_func_(q_dev, k_dev, v_dev)
+        latency_res = bench_func_.benchmark_result.nc_latency
+        p50 = latency_res.get_latency_percentile(50)
+        assert p50 <= latency * 1.05 # short-running kernels are subject to hardware fluctuation
+        assert os.path.getsize(test_trace_file_path) > 0
+
+    @pytest.mark.simulation
+    @pytest.mark.parametrize("bs,seqlen,d,dtype", [
+        [1, 4096, 128, np.float32],
+        [1, 4096, 128, nl.bfloat16],
+        [1, 5120, 128, nl.bfloat16]
+    ])
+    def test_allocated_attention_for_SD_numberic(self, simulation_only, bs, seqlen, d, dtype):
+        q = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
+        k = np.random.random_sample((bs, d, seqlen)).astype(np.float32)
+        v = np.random.random_sample((bs, seqlen, d)).astype(np.float32)
+
+        q_dev = nl.static_cast(q, dtype)
+        k_dev = nl.static_cast(k, dtype)
+        v_dev = nl.static_cast(v, dtype)
+
+        numeric_func = baremetal(allocated_fused_self_attn_for_SD_small_head_size)
+        if simulation_only:
+            out = simulate_kernel(numeric_func[bs], q_dev, k_dev, v_dev)
+        else:
+            out = numeric_func[bs](q_dev, k_dev, v_dev)
+        out = nl.static_cast(out, np.float32)
+        golden_result = cpu_golden_attn(q, k, v)
+        assert np.allclose(out, golden_result, atol=1e-2)
diff --git a/test/unit/test_flash_attn_bwd.py b/test/unit/test_flash_attn_bwd.py
index a55abbe..0f45f9f 100644
--- a/test/unit/test_flash_attn_bwd.py
+++ b/test/unit/test_flash_attn_bwd.py
@@ -2,12 +2,14 @@
 Copyright (c) 2023, Amazon.com. All Rights Reserved
 """
 import pytest
-from neuronxcc.nki.kernels.attention import flash_attn_bwd
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.attention import flash_attn_bwd
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
 
-numeric_func = baremetal(flash_attn_bwd)
+xfail = pytest.mark.arch_specific_xfail
+
+
 bench_func = benchmark(warmup=5, iters=10)(flash_attn_bwd)
 
 def softmax(x: np.ndarray, dim: int, zero_max_mode=False,
@@ -85,6 +87,7 @@ def mixed_precision_matmul(a, b):
 
 class TestAttention:
 
+    @xfail # P167481231
     @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, latency", [
         [1, 4, 32*1024, 128, nl.bfloat16, 117000],
     ])
@@ -97,30 +100,24 @@ def test_flash_attn_bwd_perf(self, bs, nheads, seqlen, d, dtype, latency):
         lse = np.random.random_sample([bs, nheads, nl.tile_size.pmax, seqlen // nl.tile_size.pmax]).astype(np.float32)
         seed = None
 
-        out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        out_dk = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        
         q = nl.static_cast(q, dtype)
         k = nl.static_cast(k, dtype)
         v = nl.static_cast(v, dtype)
         o_proj = nl.static_cast(o_proj, dtype)
         dy = nl.static_cast(dy, dtype)  
-        out_dq = nl.static_cast(out_dq, dtype)
-        out_dk = nl.static_cast(out_dk, dtype)
-        out_dv = nl.static_cast(out_dv, dtype)
-
-        bench_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
-                               out_dq, out_dk, out_dv,
-                               use_causal_mask=True, mixed_precision=True)
-        latency_res = bench_func.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
+
+        bench_func_ = bench_func[bs, nheads]
+        bench_func_(q, k, v, o_proj, dy, lse, seed,
+                    use_causal_mask=True, mixed_precision=True)
+        latency_res = bench_func_.benchmark_result.nc_latency
+        p99 = latency_res.get_latency_percentile(50)
         assert p99 <= latency
 
+    @pytest.mark.simulation
     @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype", [
         [1, 4, 4096, 128, np.float32],
     ])
-    def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype):
+    def test_flash_attn_bwd_numerical(self, simulation_only, bs, nheads, seqlen, d, dtype):
         q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
         k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
         v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
@@ -130,10 +127,7 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype):
         v = nl.static_cast(v, dtype)
         dy = nl.static_cast(dy, dtype)
         seed = None
-        out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        out_dk = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-        out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype)
-  
+
         dq_golden, dk_golden, dv_golden, cached_negative_max, cached_sum_reciprocal, o_proj = \
           cpu_attention_backward(q, k, v, dy, use_causal_mask=True)
         cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
@@ -142,9 +136,15 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype):
                                                               nl.tile_size.pmax).transpose(0, 1, 3, 2)
         lse = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal))
 
-        numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
-                                 out_dq, out_dk, out_dv,
-                                 use_causal_mask=True, mixed_precision=True)
+        numeric_func = baremetal(flash_attn_bwd)
+        if simulation_only:
+            out_dq, out_dk, out_dv = simulate_kernel(numeric_func[bs, nheads], q, k, v, o_proj, dy, lse, seed,
+                                                     use_causal_mask=True,
+                                                     mixed_precision=True)
+        else:
+            out_dq, out_dk, out_dv = numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed,
+                                                              use_causal_mask=True,
+                                                              mixed_precision=True)
 
         assert np.allclose(out_dq, dq_golden, atol=1e-2)
         assert np.allclose(out_dk, dk_golden, atol=1e-2)
diff --git a/test/unit/test_flash_attn_fwd.py b/test/unit/test_flash_attn_fwd.py
index 4d91164..e52354d 100644
--- a/test/unit/test_flash_attn_fwd.py
+++ b/test/unit/test_flash_attn_fwd.py
@@ -2,12 +2,11 @@
 Copyright (c) 2023, Amazon.com. All Rights Reserved
 """
 import pytest
-from neuronxcc.nki.kernels.attention import flash_fwd, FlashConfig
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.attention import flash_fwd, FlashConfig
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
- 
-numeric_func = baremetal(flash_fwd)
+
 bench_func = benchmark(warmup=5, iters=10)(flash_fwd)
  
 def softmax(x: np.ndarray, dim: int, zero_max_mode=False,
@@ -63,75 +62,93 @@ def mixed_precision_matmul(a, b):
  
 class TestAttention:
  
-    @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\
+    @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\
                               mixed_precision, training, tile_size, kv_heads, should_transpose_v, latency", [
-    [1, 6, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000],
-    [1, 1, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000],
+    [1, 6, 32*1024, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000],
+    [1, 1, 32*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000],
+    # Non-square
+    [1, 3, 32*1024, 16*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000],
+    [1, 3, 16*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000],
     ])
-    def test_flash_attn_fwd_perf(self, bs, nheads, seqlen, d, dtype, use_causal_mask, 
+    def test_flash_attn_fwd_perf(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, 
                                  mixed_precision, training, tile_size, kv_heads, should_transpose_v,latency):
-        q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
-        k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
+        q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2
+        k = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2
         if should_transpose_v:
-            v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
+            v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2
         else:
-            v = (np.random.random_sample([bs, nheads, seqlen, d]) - 0.5) * 2
-        o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype)
-        out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax], 
+            v = (np.random.random_sample([bs, nheads, seqlen_k, d]) - 0.5) * 2
+        o_proj = np.zeros(shape=[bs, nheads, seqlen_q, d], dtype=dtype)
+        out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen_q // nl.tile_size.pmax], 
                                   dtype=nl.float32 if mixed_precision else dtype) if training else None
         seed = None
         
         q = nl.static_cast(q, dtype)
         k = nl.static_cast(k, dtype)
         v = nl.static_cast(v, dtype)
-        o_proj = nl.static_cast(o_proj, dtype)
         config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v})
 
         heads = nheads if kv_heads is None else kv_heads
 
-        bench_func[bs, heads](q, k, v, seed, o_proj, out_lse,
-                               use_causal_mask=use_causal_mask, mixed_precision=mixed_precision, config=config)
-        latency_res = bench_func.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
+        bench_func_ = bench_func[bs, heads]
+        bench_func_(q, k, v, seed, use_causal_mask=use_causal_mask,
+                    mixed_precision=mixed_precision, config=config)
+        latency_res = bench_func_.benchmark_result.nc_latency
+        p99 = latency_res.get_latency_percentile(50)
         assert p99 <= latency
- 
-    @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\
+    
+    @pytest.mark.simulation
+    @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\
                               training, tile_size, kv_heads, should_transpose_v", [
-    [1, 6, 4096, 128, np.float32, True, True, 2048, 3, False],
-    [1, 1, 4096, 128, np.float32, True, False, 2048, None, False],
+    [1, 6, 4096, 4096, 128, np.float32, True, True, 2048, 3, False],
+    [1, 1, 4096, 4096, 128, np.float32, True, False, 2048, None, False],
+    [1, 1, 8192, 4096, 128, np.float32, True, False, 2048, None, False],
+    [1, 1, 4096, 8192, 128, np.float32, True, False, 2048, None, False],
     ])
-    def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen, d, dtype, use_causal_mask, 
+    def test_flash_attn_fwd_numerical(self, simulation_only, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, 
                                      training, tile_size, kv_heads, should_transpose_v):
-        q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
-        k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen]) - 0.5) * 2
+        q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2
+        k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen_k]) - 0.5) * 2
         if should_transpose_v:
-            v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2
+            v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2
             cpu_permute = (0, 1, 2, 3)
         else:
-            v = (np.random.random_sample([bs, kv_heads or nheads, seqlen, d]) - 0.5) * 2
+            v = (np.random.random_sample([bs, kv_heads or nheads, seqlen_k, d]) - 0.5) * 2
             cpu_permute = (0, 1, 3, 2)
-        o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype)
+
         q = nl.static_cast(q, dtype)
         k = nl.static_cast(k, dtype)
         v = nl.static_cast(v, dtype)
         seed = None
-        out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax], 
-                                  dtype=np.float32) if training else None
 
         o_proj_golden, cached_negative_max, cached_sum_reciprocal  = \
           cpu_attention_forward(q, k, v.transpose(cpu_permute), use_causal_mask=use_causal_mask,mixed_precision=True)
         o_proj_golden = o_proj_golden.transpose(0,1,3,2) # (b,h, d, seq)
-        cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
+        cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax,
                                                           nl.tile_size.pmax).transpose(0, 1, 3, 2)
-        cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen // nl.tile_size.pmax,
+        cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax,
                                                               nl.tile_size.pmax).transpose(0, 1, 3, 2)
         lse_golden = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal)) if training else None
         config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v})
 
         heads = nheads if kv_heads is None else kv_heads
-        numeric_func[bs, heads](q, k, v, seed, o_proj, out_lse, seed,
-                                 use_causal_mask=use_causal_mask, mixed_precision=True, config=config)
 
-        assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
+        numeric_func = baremetal(flash_fwd)
+        if simulation_only:
+            results = simulate_kernel(numeric_func[bs, heads], q, k, v, seed,
+                                          use_causal_mask=use_causal_mask,
+                                          mixed_precision=True,
+                                          config=config)
+        else:
+            results = numeric_func[bs, heads](q, k, v, seed,
+                                          use_causal_mask=use_causal_mask,
+                                          mixed_precision=True,
+                                          config=config)
+
         if training:
+            o_proj, out_lse = results
+            assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
             assert np.allclose(out_lse, lse_golden, atol=1e-2)
+        else:
+            o_proj = results
+            assert np.allclose(o_proj, o_proj_golden, atol=1e-2)
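
As the updated test shows, `flash_fwd` now returns its outputs instead of writing into pre-allocated arrays, and the number of returned tensors depends on `FlashConfig(training=...)`. A minimal inference-mode call, sketched from the parameters used above (shapes and values are only an example, not the kernel's contract):

```python
# Sketch of a simulated inference-mode flash_fwd call, mirroring the test above.
import numpy as np
import neuronxcc.nki.language as nl
from neuronxcc.nki import baremetal, simulate_kernel
from nki_samples.reference.attention import flash_fwd, FlashConfig

bs, nheads, seqlen, d = 1, 1, 4096, 128
q = nl.static_cast((np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2, np.float32)
k = nl.static_cast((np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2, np.float32)
v = nl.static_cast((np.random.random_sample([bs, nheads, seqlen, d]) - 0.5) * 2, np.float32)

config = FlashConfig(seq_tile_size=2048, training=False, should_transpose_v=False)
numeric_func = baremetal(flash_fwd)
# training=False returns only o_proj; training=True additionally returns out_lse.
o_proj = simulate_kernel(numeric_func[bs, nheads], q, k, v, None,
                         use_causal_mask=True, mixed_precision=True, config=config)
```
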
diff --git a/test/unit/test_neuron_profile.py b/test/unit/test_neuron_profile.py
new file mode 100644
index 0000000..e607705
--- /dev/null
+++ b/test/unit/test_neuron_profile.py
@@ -0,0 +1,86 @@
+from neuronxcc.nki import benchmark
+from neuronxcc.nki import profile
+import neuronxcc.nki.language as nl
+import numpy as np
+import pytest
+import os
+import shutil
+import tempfile
+
+
+WORKING_DIRECTORY = tempfile.mkdtemp()
+SAVE_NEFF_NAME = "cus_file123.neff"
+SAVE_TRACE_NAME = "profile-custom.ntff"
+NUM_EXECS = 20
+PROFILE_NTH = 10  
+JSON_REPORTS = "json_reports"
+
+@profile(working_directory=WORKING_DIRECTORY, save_neff_name=SAVE_NEFF_NAME, overwrite=False , save_trace_name=SAVE_TRACE_NAME, num_execs=NUM_EXECS, profile_nth=PROFILE_NTH)
+def nki_tensor_tensor_add(a_tensor, b_tensor):
+  c_output = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, buffer=nl.shared_hbm)
+ 
+  a = nl.load(a_tensor)
+  b = nl.load(b_tensor)
+
+  c_tile = a + b
+
+  nl.store(c_output, value=c_tile)
+
+  return c_output
+
+class TestNeuronProfile:
+    def _get_ntff_path(self, trace_val):
+        """
+        Prepares ntff file name based on execution trace number
+        """
+        if trace_val == 1:
+            return os.path.join(WORKING_DIRECTORY, f"{os.path.splitext(os.path.basename(SAVE_TRACE_NAME))[0]}.ntff")
+        else:
+            return os.path.join(WORKING_DIRECTORY, f"{os.path.splitext(os.path.basename(SAVE_TRACE_NAME))[0]}_exec_{trace_val}.ntff")
+
+    @pytest.fixture
+    def traces(self):
+        ret = []
+        if NUM_EXECS < PROFILE_NTH:
+            ret.append(self._get_ntff_path(PROFILE_NTH))
+        else:
+            curr = PROFILE_NTH
+            while curr <= NUM_EXECS:
+                ret.append(self._get_ntff_path(curr))
+                curr += PROFILE_NTH
+        return ret
+    
+    @pytest.fixture
+    def num_reports(self):
+        if NUM_EXECS < PROFILE_NTH:
+            return 1
+        else:
+            return NUM_EXECS // PROFILE_NTH
+
+    def test_output_artifacts_created(self, traces, num_reports):
+        # delete artifact directory, only testing non-overwrite functionality
+        if os.path.exists(WORKING_DIRECTORY):
+            shutil.rmtree(WORKING_DIRECTORY)
+
+        # creates dummy input to invoke profile kernel
+        a = np.zeros([128, 1024]).astype(np.float16)
+        b = np.random.random_sample([128, 1024]).astype(np.float16)
+
+        output_nki = nki_tensor_tensor_add(a, b)
+
+        # now asserting artifacts are correctly created     
+        assert os.path.exists(os.path.join(WORKING_DIRECTORY, SAVE_NEFF_NAME)) # neff
+        
+        for trace in traces:
+            assert os.path.exists(trace) # trace
+        
+        # json reports
+        report_dir = os.path.join(WORKING_DIRECTORY, JSON_REPORTS)
+
+        assert os.path.exists(report_dir) # actually exists
+        assert len(os.listdir(report_dir)) == num_reports # report all iterations queried
+
+        # post condition cleanup
+        if os.path.exists(WORKING_DIRECTORY):
+            shutil.rmtree(WORKING_DIRECTORY)
+
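
With the constants at the top of this test (NUM_EXECS = 20, PROFILE_NTH = 10), every tenth execution is profiled, so two trace files and two JSON reports are expected. A quick standalone check of the names the `traces` fixture computes, illustrative only:

```python
# Mirrors the naming logic of _get_ntff_path for SAVE_TRACE_NAME = "profile-custom.ntff".
NUM_EXECS, PROFILE_NTH = 20, 10
expected = ["profile-custom.ntff" if n == 1 else f"profile-custom_exec_{n}.ntff"
            for n in range(PROFILE_NTH, NUM_EXECS + 1, PROFILE_NTH)]
print(expected)  # ['profile-custom_exec_10.ntff', 'profile-custom_exec_20.ntff']
```
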
diff --git a/test/unit/test_resize_nearest.py b/test/unit/test_resize_nearest.py
index a77968b..72e7aef 100644
--- a/test/unit/test_resize_nearest.py
+++ b/test/unit/test_resize_nearest.py
@@ -3,14 +3,14 @@
 """
 import pytest
 
-from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.vision import resize_nearest_fixed_dma_kernel
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
 
-numeric_func = baremetal(resize_nearest_fixed_dma_kernel)
 bench_func = benchmark(warmup=5, iters=10)(resize_nearest_fixed_dma_kernel)
 
+
 def cpu_golden_result(data_tensor, output_shape):
     in_b, in_h, in_w, in_c = data_tensor.shape
     out_b, out_h, out_w, out_c = output_shape
@@ -36,33 +36,37 @@ def cpu_golden_result(data_tensor, output_shape):
 class TestResizeNearest:
 
     @pytest.mark.parametrize("in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency", [
- 	    [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1722],
+ 	    [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1740],
         [1, 30, 20, 1280, 1, 59, 38, 1280, nl.float16, 659],
         [1, 30, 20, 1280, 1, 59, 38, 1280, nl.bfloat16, 659],
  	])
     def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency):
         input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32)
-        output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype)
-        
+
         input_dev = nl.static_cast(input_tensor, dtype)
 
-        bench_func[in_b](input_dev, output_tensor)
-        latency_res = bench_func.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
+        bench_func_ = bench_func[in_b]
+        bench_func_(input_dev, (out_b, out_h, out_w, out_c))
+        latency_res = bench_func_.benchmark_result.nc_latency
+        p99 = latency_res.get_latency_percentile(50)
         assert p99 <= latency
 
+    @pytest.mark.simulation
     @pytest.mark.parametrize("in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype", [
  	    [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32],
         [1, 30, 20, 1280, 1, 59, 38, 1280, nl.float16],
         [1, 30, 20, 1280, 1, 59, 38, 1280, nl.bfloat16],
  	])
-    def test_resize_nearest_for_numberic(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype):
+    def test_resize_nearest_for_numeric(self, simulation_only, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype):
         input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32)
-        output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype)
-        
+
         input_dev = nl.static_cast(input_tensor, dtype)
 
-        numeric_func[in_b](input_dev, output_tensor)
+        numeric_func = baremetal(resize_nearest_fixed_dma_kernel)
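+        # When --simulation-only is set, run the kernel through the simulator; otherwise execute it baremetal on device.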
+        if simulation_only:
+            output_tensor = simulate_kernel(numeric_func[in_b], input_dev, (out_b, out_h, out_w, out_c))
+        else:
+            output_tensor = numeric_func[in_b](input_dev, (out_b, out_h, out_w, out_c))
         output_tensor = nl.static_cast(output_tensor, np.float32)
         golden_result = cpu_golden_result(input_tensor, output_tensor.shape)
         assert np.allclose(output_tensor, golden_result, atol=1e-2)
diff --git a/test/unit/test_rmsnorm_qkv.py b/test/unit/test_rmsnorm_qkv.py
new file mode 100644
index 0000000..28838d1
--- /dev/null
+++ b/test/unit/test_rmsnorm_qkv.py
@@ -0,0 +1,69 @@
+"""
+Copyright (c) 2024, Amazon.com. All Rights Reserved
+"""
+import pytest
+from nki_samples.reference.allocated_fused_linear import allocated_fused_rms_norm_qkv
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
+import neuronxcc.nki.language as nl
+import numpy as np
+
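+# Benchmarked variant of the kernel: 5 warmup iterations followed by 10 measured iterations.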
+bench_func = benchmark(warmup=5, iters=10)(allocated_fused_rms_norm_qkv)
+
+np.random.seed(0)
+
+
+def rms_norm(hidden, gamma, eps=1e-6):
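+  # NumPy reference RMSNorm: divide by the root-mean-square over the last axis, then scale by gamma if provided.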
+  rms = np.sqrt(np.mean(np.square(hidden), axis=-1, keepdims=True))
+  norm = hidden * np.reciprocal(rms + eps)
+  if gamma is not None:
+    norm *= gamma
+  return norm
+
+def cpu_golden_result(hidden, gamma, qkv_weights, dtype, do_norm=True):
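+  # Golden reference: optionally RMS-normalize the hidden states, then project through the fused QKV weight matrix.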
+  if do_norm:
+      hidden = rms_norm(hidden, gamma)
+  qkv_out = (hidden @ qkv_weights).astype(dtype)
+  return qkv_out
+
+class TestRMSNormQKV:
+  @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype, latency", [
+    [1, 128, 512, 512, np.float16, 25],
+    [1, 512, 1024, 384, nl.bfloat16, 40],
+    [1, 128, 1024, 512, nl.bfloat16, 28],
+    # [1, 1024, 8192, 512, nl.bfloat16, 301 * 1.02], # FIXME: performance is flaky
+  ])
+  def test_allocated_rmsnorm_qkv_perf(self, batch, seqlen, dim, d_head, dtype, latency):
+    hidden = np.random.random_sample((batch, seqlen, dim)).astype(np.float32)
+    weights = np.random.random_sample((dim, d_head)).astype(np.float32)
+
+    hidden = nl.static_cast(hidden, dtype)
+    weights = nl.static_cast(weights, dtype)
+
+    bench_func(hidden, weights)
+    latency_res = bench_func.benchmark_result.nc_latency
+    p50 = latency_res.get_latency_percentile(50)
+    assert p50 <= latency
+
+  @pytest.mark.simulation
+  @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype", [
+    [1, 128, 512, 512, np.float16],
+    [1, 512, 1024, 384, nl.bfloat16],
+    [1, 128, 1024, 512, nl.bfloat16],
+    [1, 1024, 8192, 512, nl.bfloat16]
+  ])
+  def test_allocated_rmsnorm_qkv_numeric(self, simulation_only, batch, seqlen, dim, d_head, dtype):
+    hidden = np.random.random_sample((batch, seqlen, dim))
+    weights = np.random.random_sample((dim, d_head))
+
+    hidden_dev = nl.static_cast(hidden, dtype)
+    weights_dev = nl.static_cast(weights, dtype)
+
+    numeric_func = baremetal(allocated_fused_rms_norm_qkv)
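+    # Under --simulation-only, run the kernel through the simulator; otherwise execute baremetal on device.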
+    if simulation_only:
+      out = simulate_kernel(numeric_func, hidden_dev, weights_dev)
+    else:
+      out = numeric_func(hidden_dev, weights_dev)
+    out = nl.static_cast(out, np.float32)
+    golden_res = nl.static_cast(cpu_golden_result(hidden, None, weights, dtype, do_norm=True), np.float32)
+    assert np.allclose(out, golden_res, atol=1e-2, rtol=1e-2)
+
diff --git a/test/unit/test_select_and_scatter.py b/test/unit/test_select_and_scatter.py
index 70f7a7c..08e787f 100644
--- a/test/unit/test_select_and_scatter.py
+++ b/test/unit/test_select_and_scatter.py
@@ -1,11 +1,10 @@
 import pytest
 
-from neuronxcc.nki.kernels.vision import select_and_scatter_kernel
-from neuronxcc.nki import benchmark, baremetal
+from nki_samples.reference.vision import select_and_scatter_kernel
+from neuronxcc.nki import benchmark, baremetal, simulate_kernel
 import neuronxcc.nki.language as nl
 import numpy as np
 
-numeric_func = baremetal(select_and_scatter_kernel)
 bench_func = benchmark(warmup=5, iters=10)(select_and_scatter_kernel)
 
 np.random.seed(0)
@@ -39,7 +38,6 @@ def cpu_golden_result(operand_tensor, source_tensor, window_dimensions=(3, 3), w
                     out_h = h * stride_h + local_h - padding[0]
                     out_w = w * stride_w + local_w - padding[1]
                     output_tensor[n, c, out_h, out_w] += source_tensor[n, c, h, w]
-
     return output_tensor
 
 class TestSelectAndScatter:
@@ -47,31 +45,33 @@ class TestSelectAndScatter:
  	    [8, 64, 112, 112, 56, 56, np.float32, 4500],
  	])
     def test_select_and_scatter_for_perf(self, n, c, operand_h, operand_w, source_h, source_w, dtype, latency):
-        operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32)
-        source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32)
-        output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype)
-        
-        operand_dev = nl.static_cast(operand_tensor, dtype)
-        source_dev = nl.static_cast(source_tensor, dtype)
+        operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype)
+        source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype)
 
-        bench_func(operand_dev, source_dev, output_tensor)
+        bench_func(operand_dev, source_dev)
         latency_res = bench_func.benchmark_result.nc_latency
-        p99 = latency_res.get_latency_percentile(99)
+        p99 = latency_res.get_latency_percentile(50)
         assert p99 <= latency
 
+    @pytest.mark.simulation
     @pytest.mark.parametrize("n, c, operand_h, operand_w, source_h, source_w, dtype", [
  	    [8, 64, 112, 112, 56, 56, np.float32],
- 	    pytest.param(8, 64, 112, 112, 56, 56, nl.bfloat16, marks=pytest.mark.xfail),
+ 	    [8, 64, 112, 112, 56, 56, nl.bfloat16],
  	])
-    def test_select_and_scatter_for_numeric(self, n, c, operand_h, operand_w, source_h, source_w, dtype):
-        operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32)
-        source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32)
-        output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype)
-        
-        operand_dev = nl.static_cast(operand_tensor, dtype)
-        source_dev = nl.static_cast(source_tensor, dtype)
+    def test_select_and_scatter_for_numeric(self, simulation_only, n, c, operand_h, operand_w, source_h, source_w, dtype):
+        operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype)
+        source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype)
+
+        operand_tensor = nl.static_cast(operand_dev, np.float32)
+        source_tensor = nl.static_cast(source_dev, np.float32)
 
-        numeric_func(operand_dev, source_dev, output_tensor)
+        numeric_func = baremetal(select_and_scatter_kernel)
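+        # Simulation runs go through simulate_kernel; hardware runs invoke the baremetal kernel directly.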
+        if simulation_only:
+            output_dev = simulate_kernel(numeric_func, operand_dev, source_dev)
+        else:
+            output_dev = numeric_func(operand_dev, source_dev)
         golden_result = cpu_golden_result(operand_tensor, source_tensor)
-        output_tensor = nl.static_cast(output_tensor, np.float32)
-        assert np.allclose(output_tensor, golden_result)
\ No newline at end of file
+        nki_result = nl.static_cast(output_dev, np.float32)
+
+        assert np.allclose(nki_result, golden_result, rtol=1e-2, atol=1e-2)