Inconsistent output when mixing eager with torch.compile #8832

Open
liangfu opened this issue Mar 13, 2025 · 1 comment
Labels
bug Something isn't working dynamo

Comments

@liangfu

liangfu commented Mar 13, 2025

🐛 Bug

When mixing eager execution with torch.compile, the output tensors are inconsistent: running write_to_kv_cache through torch.compile and then running it eagerly leaves key_cache with different contents for the same inputs.

To Reproduce

import torch
import torch_xla.core.xla_model as xm

def write_to_kv_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
) -> None:
    # Mark the caches as buffer donors so the compiled graph may reuse their storage in place.
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    # Flatten keys/values and caches to 2-D (tokens, head_size) views, then
    # scatter the new rows into the caches at the slots given by slot_mapping.
    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)

if __name__ == '__main__':
    device = xm.xla_device()
    num_blocks = 128
    block_size = 128
    num_kv_heads = 4
    head_size = 64
    kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
    kv_cache = torch.zeros(kv_cache_shape,
                           dtype=torch.float,
                           device=device)
    key_cache, value_cache = kv_cache

    num_heads = 64
    kv = torch.empty(1, 3, 2, num_kv_heads, head_size, dtype=torch.float, device=device)
    kv.uniform_(-1,1)
    key, value = kv.unbind(dim=2)
    slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=torch.int32, device=device).long()
    compiled_callable = torch.compile(write_to_kv_cache,
                                      backend="openxla",
                                      fullgraph=False,
                                      dynamic=False)
    compiled_callable(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use torch compile {key_cache[0][:5]}")

    compiled_callable(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use torch compile again func {key_cache[0][:5]}")

    write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use original func {key_cache[0][:5]}")

Expected behavior

All three printed key_cache slices should be identical: the torch.compile'd calls and the eager call to write_to_kv_cache should update key_cache and value_cache with the same values.
Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]:
  • torch_xla version:

Additional context

@miladm
Collaborator

miladm commented Mar 13, 2025

@liangfu this is a great project that I'd like to see us scope and land.

AFAIU, torch.compile() + eager integration has faced and addressed some challenges (in the context of vLLM V1). Please feel free to concretely list the technical issues you observe with the XLA backend. Then, I'd suggest we (a) consult our Meta partners on the general torch.compile() + eager hybrid challenges they have fixed or are fixing, and (b) use your list of issues to scope and enable the compile + eager hybrid execution model for the XLA backend.

cc @yaoshiang @bhavya01 @tengyifei @bdhirsh @zou3519 @WoosukKwon @ezyang for visibility

@ysiraichi ysiraichi added bug Something isn't working dynamo labels Mar 14, 2025