[INTERPRETER] Improve support for aliased tensors in interpreter (#5890)
If two tensors alias, the interpreter will clobber writes when it makes a
separate CPU copy of each tensor's backing storage: restoring one copy to the
device overwrites writes made through the other. Detecting aliasing in general
is quite difficult, but in most cases aliasing tensors share the same backing
storage. If we set up the CPU-side storage to mirror the device, we can avoid
this particular bug in that case. This partially fixes #5791.
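To make the failure mode concrete, here is a minimal sketch in plain PyTorch
(illustrative only, not the interpreter's code) of how independent CPU round
trips lose writes when two arguments alias:

```python
import torch

x = torch.zeros(4)
a, b = x[:], x[:]  # two aliasing views of one storage

# Independent CPU round trip per argument, as before this change:
a_cpu, b_cpu = a.clone(), b.clone()  # one separate "host" copy per argument
a_cpu[0] = 1.0  # a kernel write lands in the first copy
a.copy_(a_cpu)  # restoring the first copy: x[0] becomes 1
b.copy_(b_cpu)  # restoring the second copy clobbers it: x[0] is 0 again

assert x[0] == 0  # the write through `a` was lost
```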


# New contributor declaration
- [X] I am not making a trivial change, such as fixing a typo in a comment.

- [ ] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [X] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [X] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
saagarjha authored Feb 12, 2025
1 parent b2a86b1 commit 08d7f64
Showing 2 changed files with 26 additions and 9 deletions.
12 changes: 12 additions & 0 deletions python/test/unit/language/test_core.py
```diff
@@ -7025,3 +7025,15 @@ def _simple_add(
     _simple_add[grid](x, x.stride(0), x.stride(1))
 
     assert torch.allclose(x, torch.ones_like(x) * c_dim)
+
+
+@pytest.mark.interpreter
+def test_aliasing(device):
+
+    @triton.jit
+    def aliasing_kernel(buffer, buffer2):
+        triton.language.store(buffer, 1)
+
+    buffer = torch.zeros(1, device=device)
+    aliasing_kernel[(1, )](buffer, buffer)
+    assert buffer[0] == 1
```
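The test aliases the two parameters by passing the same tensor twice. Aliasing
also commonly arises from overlapping views, which is why the commit message
notes that aliasing tensors usually share one backing storage. A small
illustrative check (not part of the commit; assumes a PyTorch recent enough to
have `untyped_storage()`):

```python
import torch

x = torch.zeros(4)
head, also_head = x[:2], x[0:2]  # overlapping views of x
# Both views report the same storage pointer, which is exactly what the
# interpreter's dedup key (`untyped_storage().data_ptr()`) relies on.
assert head.untyped_storage().data_ptr() == also_head.untyped_storage().data_ptr()
```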
23 changes: 14 additions & 9 deletions python/triton/runtime/interpreter.py
```diff
@@ -1157,16 +1157,22 @@ def __init__(self, fn, arg_names, grid):
         self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
 
     def _init_args_hst(self, args_dev, kwargs):
+        storages = {}
+
         def _to_cpu(arg):
             if isinstance(arg, tuple):
                 return _tuple_create(arg, map(_to_cpu, arg))
             elif not hasattr(arg, "data_ptr"):
                 return arg
 
             unwrapped_arg = _unwrap_tensor(arg)
+            if unwrapped_arg.untyped_storage().data_ptr() not in storages:
+                storage = unwrapped_arg.untyped_storage()
+                storages[storage.data_ptr()] = storage.cpu()
+
+            storage = storages[unwrapped_arg.untyped_storage().data_ptr()]
             cpu_arg = unwrapped_arg.new_empty(0, device='cpu')
-            cpu_arg.set_(unwrapped_arg.untyped_storage().cpu(), unwrapped_arg.storage_offset(), unwrapped_arg.size(),
-                         unwrapped_arg.stride())
+            cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride())
             cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
             return cpu_arg
 
```
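Read in isolation, the new `_to_cpu` does two things: it keeps one CPU copy
per device storage (keyed by `data_ptr()`), and it rebuilds every host tensor
as a view into that shared copy using the original offset, size, and stride.
A simplified, self-contained sketch of the same idea (the helper name is
illustrative, not the interpreter's API):

```python
import torch

def mirror_to_cpu(tensors):
    """Copy tensors to CPU while preserving aliasing between them."""
    storages = {}  # device storage data_ptr -> one shared CPU copy
    out = []
    for t in tensors:
        key = t.untyped_storage().data_ptr()
        if key not in storages:
            storages[key] = t.untyped_storage().cpu()
        cpu_t = t.new_empty(0, device="cpu")
        # View the shared CPU storage with the original offset/size/stride,
        # so tensors that alias on the device keep aliasing on the CPU.
        cpu_t.set_(storages[key], t.storage_offset(), t.size(), t.stride())
        out.append(cpu_t)
    return out

a = torch.arange(4.0)
va, vb = mirror_to_cpu([a[:2], a[1:3]])  # overlapping views
va[1] = 42.0
assert vb[0] == 42.0  # aliasing is preserved on the host side
```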
```diff
@@ -1175,21 +1181,17 @@ def _to_cpu(arg):
         # Process keyword arguments
         kwargs_hst = {}
         for key, value in kwargs.items():
-            if hasattr(value, "data_ptr"):
-                kwargs_hst[key] = value.cpu()
-            elif isinstance(value, tuple):
-                return _tuple_create(value, map(_to_cpu, value))
-            else:
-                kwargs_hst[key] = value
+            kwargs_hst[key] = _to_cpu(value)
         return args_hst, kwargs_hst
 
     def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
+        storages = {}
+
         def _from_cpu(arg_dev, arg_hst):
             if hasattr(arg_dev, "data_ptr"):
                 # No need to rewrap because this just modifies internal
                 arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst)
-                arg_dev.untyped_storage().copy_(arg_hst.untyped_storage())
+                storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage())
             elif isinstance(arg_dev, tuple):
                 for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
                     _from_cpu(arg_dev, arg_hst)
```
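In this hunk, keyword arguments are routed through the same `_to_cpu` helper,
so they share the `storages` map with positional arguments (the old path also
contained a stray early `return` for tuple-valued kwargs, which disappears in
the rewrite). Meanwhile `_from_cpu` stops copying immediately and instead
records one (device storage, host storage) pair per `data_ptr()`; the next
hunk performs the deferred copy-back once per unique storage. A simplified
sketch of that copy-back (names are illustrative):

```python
import torch

def restore_to_device(pairs):
    """Copy host data back to the device, once per unique storage.

    `pairs` holds (device_tensor, host_tensor) for every kernel argument;
    the helper name and shape are illustrative, not the interpreter's API.
    """
    storages = {}
    for dev, hst in pairs:
        # Aliasing arguments share a device storage, so they collapse into
        # a single entry here instead of clobbering each other's writes.
        storages[dev.untyped_storage().data_ptr()] = (dev.untyped_storage(), hst.untyped_storage())
    for dev_storage, hst_storage in storages.values():
        dev_storage.copy_(hst_storage)  # one copy per storage, not per argument
```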
```diff
@@ -1202,6 +1204,9 @@ def _from_cpu(arg_dev, arg_hst):
             kwarg_hst = kwargs_hst[key]
             _from_cpu(kwarg_dev, kwarg_hst)
 
+        for (arg_dev, arg_hst) in storages.values():
+            arg_dev.copy_(arg_hst)
+
     def __call__(self, *args_dev, **kwargs):
         if kwargs.pop("warmup", False):
             return
```
