
Generating stablehlo.composite and running it through PJRT #8792

Status: Open
sechkova opened this issue Mar 5, 2025 · 1 comment
Labels: bug (Something isn't working), stablehlo (StableHLO related work)

Comments

@sechkova commented Mar 5, 2025

❓ Questions and Help

Following the example from the docs, I tried to use StableHLOCompositeBuilder to generate a stablehlo.composite op, with the difference that I want to actually run it through PJRT instead of exporting it.
Is there currently a way to do this, or are there any plans to support it?
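
For context, the documented flow exports the program rather than executing it. A minimal sketch of that path, assuming the torch.export and torch_xla.stablehlo APIs from the docs (using the same M module defined below; I have not re-verified this exact snippet):

import torch
from torch_xla.stablehlo import exported_program_to_stablehlo

# Export first, then lower the ExportedProgram to StableHLO text/bytecode;
# nothing here runs through PJRT.
exported = torch.export.export(M(), (torch.randn(10, 8, 128),))
shlo = exported_program_to_stablehlo(exported)
print(shlo.get_stablehlo_text())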

This is my example code:

import os
# Compile lazy-tensor graphs through the StableHLO path instead of HLO directly.
os.environ['XLA_STABLEHLO_COMPILE'] = '1'

import torch
import torch.nn.functional as F
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

class M(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(128, 128, bias=False)
        self.k_proj = torch.nn.Linear(128, 128, bias=False)
        self.v_proj = torch.nn.Linear(128, 128, bias=False)
        # Ops traced between mark_inputs and mark_outputs are wrapped in a
        # stablehlo.composite named "test.sdpa" with the given attributes.
        self.b = StableHLOCompositeBuilder("test.sdpa", {"scale": 0.25, "other_attr": "val"})

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q, k, v = self.b.mark_inputs(q, k, v)
        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
        attn_out = self.b.mark_outputs(attn_out)
        attn_out = attn_out + x
        return attn_out

device = "xla"

input_args = torch.randn((10, 8, 128)).to(device)
model = M().to(device)
out = model(input_args)
print(out)
Running this fails during the MHLO -> StableHLO conversion at compile time:

WARNING:root:Found CUDA without GPU_NUM_DEVICES. Defaulting to PJRT_DEVICE=CUDA with GPU_NUM_DEVICES=1

loc("select.69"): error: 'stablehlo.select' op using value defined outside the region

...

RuntimeError: torch_xla/csrc/runtime/stablehlo_helper.cc:109 : Check failed: status.ok()
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::ConvertHloToStableHlo(xla::HloModuleProto const*, mlir::ModuleOp*)
	torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
	torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >&, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
	torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
	torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, bool, bool, bool)

...

*** End stack trace ***
MHLO -> StableHLO conversion failed.
StableHLO Module from MHLO -> StableHLO conversion is not leagal.Please open a github issue to PyTorch/XLA.

I used torch-xla 2.5.1 for the example above, but I get a similar error with 2.6:

torch      2.5.1
torch-xla  2.5.1
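
In case it is useful for triage: the pending graph's HLO can be dumped before the failing MHLO -> StableHLO conversion runs. A minimal sketch, assuming the internal torch_xla._XLAC._get_xla_tensors_hlo helper (an internal API, so it may change between releases):

import torch_xla

# Call this instead of print(out): it renders the HLO for the pending
# graph without forcing the compilation that fails above.
print(torch_xla._XLAC._get_xla_tensors_hlo([out]))

Per the stack trace, this is the graph that PjRtComputationClient::Compile hands to ConvertHloToStableHlo in stablehlo_helper.cc.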
@ysiraichi added the bug (Something isn't working) and stablehlo (StableHLO related work) labels on Mar 6, 2025
@ysiraichi (Collaborator) commented:

Thank you for filing this issue.
I was able to reproduce this on 2.7.0+git9b61c1a.

@tengyifei @ManfeiBai @zpcore Do you know what's happening, here?
