torch_xla scan forces inputs to have gradients #8783

Open
tengyifei opened this issue Mar 4, 2025 · 0 comments
Assignees: tengyifei
Labels: bug (Something isn't working)

Comments

@tengyifei (Collaborator) commented:

The snippet

# Make some fake tensors to trace the user function and obtain the
# forward and backward graphs. Note that the init/carry fake tensor
# always requires grad. That's because even if the user passed in some
# `init` that does not require grad, we still want gradients to flow
# through the `carry` from one iteration of the user function to the
# next. In summary, the `carry` argument used to trace a user function
# to get a correct backward pass always requires grad.
def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
  return torch.empty_like(
      v, dtype=v.dtype, device=v.device, requires_grad=requires_grad)

is probably wrong.

The most obvious example: if one of the inputs is an integer tensor, it cannot possibly require gradients, since PyTorch only supports autograd for floating point and complex dtypes.
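
A minimal standalone repro sketch (hypothetical example mirroring the make_fake_tensor call above, not code from the repository):

import torch

# An integer tensor, e.g. a step counter threaded through the scan carry.
v = torch.zeros(3, dtype=torch.int64)

# Unconditionally requesting gradients, as make_fake_tensor does, raises:
# RuntimeError: Only Tensors of floating point and complex dtype can require gradients
fake = torch.empty_like(v, dtype=v.dtype, device=v.device, requires_grad=True)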

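One possible direction for a fix (a sketch only, not a confirmed design) would be to request gradients only for dtypes that can carry them:

def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
  # Sketch: only floating point and complex dtypes support autograd,
  # so integer/bool carries are traced without requires_grad.
  requires_grad = v.is_floating_point() or v.is_complex()
  return torch.empty_like(
      v, dtype=v.dtype, device=v.device, requires_grad=requires_grad)

Gradients still would not flow through such a carry, but they cannot for integer dtypes anyway; whether the backward trace then handles a mix of grad/no-grad carries correctly is a separate question.
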
@tengyifei tengyifei self-assigned this Mar 4, 2025
@ysiraichi ysiraichi added the bug Something isn't working label Mar 5, 2025
Projects: None yet
Development: No branches or pull requests
2 participants