Fused Self Attention Produces Large Errors When Using torch.randn Instead of torch.rand #51

Open
Jerry2423 opened this issue Feb 18, 2025 · 0 comments
Labels
bug Something isn't working

Comments

@Jerry2423
Copy link

Describe the bug

In sd_attention_torch.py, when q_tensor (and likewise k_tensor and v_tensor) is generated with torch.rand, the output of fused_self_attn_for_SD_small_head_size matches the expected output of cpu_golden_attn. However, when the tensors are generated with torch.randn, the computed results show a significant discrepancy.
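For context, cpu_golden_attn serves as the CPU reference the kernel output is checked against. A minimal sketch of such a reference (an assumption of standard scaled dot-product attention — the actual implementation lives in the repository's sd_attention_torch.py) is:

```python
import torch

def cpu_golden_attn_sketch(q, k, v):
    # Standard scaled dot-product attention: softmax(q @ k.T / sqrt(d)) @ v.
    # This is a hypothetical stand-in for the repo's cpu_golden_attn.
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / (d ** 0.5)
    return torch.softmax(scores, dim=-1) @ v
```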

Expected Behavior

  q_tensor = torch.randn((4096, 64), dtype=torch.float32).to(device=device)
  k_tensor = torch.randn((4096, 64), dtype=torch.float32).to(device=device)
  v_tensor = torch.randn((4096, 64), dtype=torch.float32).to(device=device)

  output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor)
  output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor)

  allclose = torch.allclose(output_torch, output_nki, atol=1e-5, rtol=1e-3)
  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

Expected output - "NKI and Torch match"

Current Behavior

NKI and Torch differ

Reproduction Steps

In sd_attention_torch.py, replace q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device) with q_tensor = torch.randn((4096, 64), dtype=torch.float32).to(device=device). Do the same for k_tensor and v_tensor.
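One plausible contributor to the discrepancy (an assumption on my part, not confirmed in this issue): torch.rand draws from U[0, 1), so all attention logits are positive and tightly clustered, while torch.randn draws from N(0, 1), giving zero-mean logits with a much larger per-row spread. A wider logit range makes the softmax sharper and amplifies low-precision accumulation error in a fused kernel. A small sketch showing the difference in logit spread:

```python
import torch

torch.manual_seed(0)
d = 64
scale = d ** -0.5

# Uniform inputs (as in the original sd_attention_torch.py).
q_u, k_u = torch.rand((128, d)), torch.rand((128, d))
# Normal inputs (as in the modification that triggers the mismatch).
q_n, k_n = torch.randn((128, d)), torch.randn((128, d))

logits_u = (q_u @ k_u.T) * scale
logits_n = (q_n @ k_n.T) * scale

# Per-row logit range (max - min) drives softmax sharpness.
spread_u = (logits_u.max(1).values - logits_u.min(1).values).mean()
spread_n = (logits_n.max(1).values - logits_n.min(1).values).mean()
print(f"uniform spread: {spread_u:.3f}, normal spread: {spread_n:.3f}")
```

If the normal-input spread is several times the uniform-input spread, a looser tolerance (or a float64 reference) may be needed for a fair comparison.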

Regression Issue

  • Select this option if this issue appears to be a regression.

Possible Solution

No response

Additional Information/Context

No response

neuronx-cc version used

aws_neuronx_venv_pytorch_2_5_nxd_inference

Framework(s) and their versions used (JAX, PyTorch, etc..)

No response
