
Commit a02ce57

[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: e6b2c50
Pull Request resolved: #1041
1 parent ee49fc7 commit a02ce57

File tree: 3 files changed (+213 -21)


benchmarks/common/h2d_test.py (+30 -5)
@@ -14,6 +14,13 @@
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 
+@pytest.fixture(autouse=True, scope="module")
+def empty_compiler_cache():
+    torch._dynamo.reset_code_caches()
+    print("Emptying cache")
+    yield
+
+
 @pytest.fixture
 def td():
     return TensorDict(
@@ -52,20 +59,38 @@ def default_device():
         pytest.skip("CUDA/MPS is not available")
 
 
-@pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.parametrize(
+    "consolidated,compiled", [[False, False], [True, False], [True, True]]
+)
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
 )
 class TestTo:
-    def test_to(self, benchmark, consolidated, td, default_device):
+    def test_to(self, benchmark, consolidated, td, default_device, compiled):
         if consolidated:
             td = td.consolidate()
-        benchmark(lambda: td.to(default_device))
 
-    def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
+        def to(td):
+            return td.to(default_device)
+
+        if compiled:
+            to = torch.compile(to)
+            for _ in range(3):
+                to(td)
+        benchmark(to, td)
+
+    def test_to_njt(self, benchmark, consolidated, njt_td, default_device, compiled):
         if consolidated:
             njt_td = njt_td.consolidate()
-        benchmark(lambda: njt_td.to(default_device))
+
+        def to(td):
+            return td.to(default_device)
+
+        if compiled:
+            to = torch.compile(to)
+            for _ in range(3):
+                to(njt_td)
+        benchmark(to, njt_td)
 
 
 if __name__ == "__main__":
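
The compiled benchmark variants warm the function up before handing it to the pytest benchmark fixture, so that torch.compile's one-off compilation cost stays out of the timed region. A minimal sketch of that warm-up-then-measure pattern, with an illustrative function and tensor rather than the benchmark's TensorDict:

import timeit

import torch

def to_device(x, device):
    return x.to(device)

compiled_to = torch.compile(to_device)
x = torch.randn(1024)

# Warm-up: the first calls trigger compilation; afterwards only the
# steady-state, already-compiled path is measured.
for _ in range(3):
    compiled_to(x, "cpu")

print(timeit.timeit(lambda: compiled_to(x, "cpu"), number=100))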

benchmarks/compile/compile_td_test.py (+7)
@@ -23,6 +23,13 @@ class MyTensorClass:
     f: torch.Tensor
 
 
+@pytest.fixture(autouse=True, scope="module")
+def empty_compiler_cache():
+    torch._dynamo.reset_code_caches()
+    print("Emptying cache")
+    yield
+
+
 # Functions
 def add_one(td):
     return td + 1

tensordict/base.py (+176 -16)
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
 
         flat_size = []
         start = 0
+        sorting_index = 0
 
         def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
-            nonlocal start
+            nonlocal start, sorting_index
             n = value.element_size() * value.numel()
             if need_padding:
                 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
                 start,
                 stop,
                 pad,
+                flat_size[-1],
+                sorting_index,
             )
+            sorting_index = sorting_index + 1
             start = stop
 
         def assign(
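
With this change, each leaf's metadata entry also records its flattened byte size (the last flat_size entry) and a monotonically increasing sorting_index. The compile-friendly path further down uses those two trailing fields to rebuild the storage-order list of keys and split sizes without inspecting the tensors themselves. A small sketch of that recovery, using a hypothetical metadata layout whose entries end in (..., length, sorting_index):

# Hypothetical leaves: entries end with (start, stop, pad, length, sorting_index).
leaves = {
    "a": (0, 512, 0, 512, 0),
    "c": (768, 896, 0, 128, 2),
    "b": (512, 768, 0, 256, 1),
}

keys = [None] * len(leaves)
lengths = [None] * len(leaves)
for key, entry in leaves.items():
    *_, length, position = entry
    keys[position] = key
    lengths[position] = length

# keys/lengths are now in storage order and can drive storage.split(lengths),
# mirroring get_tensors_length() in the tensordict/base.py diff below.
assert keys == ["a", "b", "c"]
assert lengths == [512, 256, 128]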
@@ -10441,6 +10445,7 @@ def to(self, *args, **kwargs) -> T:
                 pin_memory=non_blocking_pin,
                 num_threads=num_threads,
                 non_blocking=non_blocking,
+                compilable=is_dynamo_compiling(),
             )
 
         if non_blocking is None:
@@ -10498,14 +10503,49 @@ def to_pinmem(tensor, _to=to):
             self._sync_all()
         return result
 
-    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
+    def _to_consolidated(
+        self, *, device, pin_memory, num_threads, non_blocking, compilable
+    ):
         if num_threads is None:
             # unspecified num_threads should mean 0
             num_threads = 0
+
         storage = self._consolidated["storage"]
-        if pin_memory:
-            storage = storage.pin_memory()
-        storage_cast = storage.to(device, non_blocking=True)
+
+        @torch.compiler.disable()
+        def to(storage):
+            if pin_memory:
+                storage = storage.pin_memory()
+            storage_cast = storage.to(device, non_blocking=True)
+            return storage_cast
+
+        storage_cast = to(storage)
+
+        if compilable:
+            result = self._to_consolidated_compile(
+                device=device, num_threads=num_threads, storage_cast=storage_cast
+            )
+        else:
+            result = self._to_consolidated_eager(
+                device=device, num_threads=num_threads, storage_cast=storage_cast
+            )
+
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
+    def _to_consolidated_eager(self, *, device, num_threads, storage_cast):
+
         untyped_storage = storage_cast.untyped_storage()
 
         def set_(x):
1057410614
}
1057510615

1057610616
result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
10577-
if non_blocking in (False, None):
10578-
if device.type == "cuda" and non_blocking is False:
10579-
# sending to CUDA force sync
10580-
cuda_device = device
10581-
elif storage.device.type == "cuda":
10582-
# sending from cuda: need sync unless intentionally not asked for
10583-
cuda_device = storage.device.type
10584-
else:
10585-
cuda_device = None
10586-
if cuda_device is not None:
10587-
torch.cuda.current_stream(cuda_device).synchronize()
10617+
return result
10618+
10619+
def _to_consolidated_compile(self, *, device, num_threads, storage_cast):
10620+
10621+
def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
10622+
root = False
10623+
if lengths is None:
10624+
lengths = []
10625+
pos = []
10626+
keys = []
10627+
root = True
10628+
for k, v in metadata["leaves"].items():
10629+
lengths.append(v[-2])
10630+
pos.append(v[-1])
10631+
keys.append(prefix + (k,))
10632+
for k, d in metadata.items():
10633+
if "leaves" in d:
10634+
get_tensors_length(
10635+
d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)
10636+
)
10637+
if root:
10638+
# l = torch.empty(len(lengths), dtype=torch.long)
10639+
# l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
10640+
out0 = [
10641+
None,
10642+
] * len(pos)
10643+
out1 = [
10644+
None,
10645+
] * len(pos)
10646+
for p, l, k in zip(pos, lengths, keys):
10647+
out0[p] = k
10648+
out1[p] = l
10649+
return out0, out1
10650+
10651+
def split_storage(consolidated):
10652+
keys, splits = get_tensors_length(consolidated["metadata"])
10653+
return dict(zip(keys, consolidated["storage"].split(splits)))
10654+
10655+
if num_threads is None:
10656+
# unspecified num_threads should mean 0
10657+
num_threads = 0
10658+
10659+
_consolidated = {"storage": storage_cast}
10660+
if "metadata" in self._consolidated:
10661+
# faster than deepcopy
10662+
def copy_dict(d):
10663+
return {
10664+
k: v if not isinstance(v, dict) else copy_dict(v)
10665+
for k, v in d.items()
10666+
}
10667+
10668+
_consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
10669+
10670+
slice_map = split_storage(_consolidated)
10671+
10672+
def view_as(src, dest):
10673+
return src.view(dest.dtype)[: dest.numel()].view(dest.shape)
1058810674

10675+
def set_(name, x):
10676+
if not isinstance(name, tuple):
10677+
name = (name,)
10678+
if x.is_nested:
10679+
from torch._subclasses.fake_tensor import FakeTensor
10680+
from torch._subclasses.functional_tensor import FunctionalTensor
10681+
from torch.nested._internal.nested_tensor import (
10682+
_tensor_symint_registry,
10683+
NestedTensor,
10684+
)
10685+
from torch.nested._internal.ops import extract_kwargs
10686+
10687+
if x.layout != torch.jagged:
10688+
raise RuntimeError(
10689+
"to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
10690+
"Please raise an issue on GitHub."
10691+
)
10692+
kwargs = extract_kwargs(x)
10693+
values = x._values
10694+
lengths = x._lengths
10695+
offsets = x._offsets
10696+
storage_offsets = slice_map[
10697+
(
10698+
*name[:-1],
10699+
"<NJT_OFFSETS>" + name[-1],
10700+
)
10701+
]
10702+
kwargs["offsets"] = view_as(storage_offsets, offsets)
10703+
if lengths is not None:
10704+
storage_lengths = slice_map[
10705+
(
10706+
*name[:-1],
10707+
"<NJT_LENGTHS>" + name[-1],
10708+
)
10709+
]
10710+
kwargs["lengths"] = view_as(storage_lengths, lengths)
10711+
ragged_source = lengths
10712+
else:
10713+
ragged_source = offsets
10714+
new_thing = kwargs.get("lengths", kwargs.get("offsets"))
10715+
if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
10716+
from torch._subclasses.functional_tensor import (
10717+
mb_unwrap_functional_tensor,
10718+
)
10719+
10720+
# Temporary hack until we have the union find
10721+
tgt = mb_unwrap_functional_tensor(new_thing)
10722+
src = mb_unwrap_functional_tensor(ragged_source)
10723+
tgt.nested_int_memo = src.nested_int_memo
10724+
else:
10725+
_tensor_symint_registry[new_thing] = _tensor_symint_registry[
10726+
ragged_source
10727+
]
10728+
10729+
storage_values = slice_map[
10730+
(
10731+
*name[:-1],
10732+
"<NJT_VALUES>" + name[-1],
10733+
)
10734+
]
10735+
return NestedTensor(
10736+
view_as(storage_values, values),
10737+
**kwargs,
10738+
)
10739+
return view_as(slice_map[name], x)
10740+
10741+
result = self._fast_apply(
10742+
set_,
10743+
device=torch.device(device),
10744+
num_threads=num_threads,
10745+
named=True,
10746+
nested_keys=True,
10747+
)
10748+
result._consolidated = _consolidated
1058910749
return result
1059010750

1059110751
def _sync_all(self):
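
The eager path rebuilds each tensor by pointing it at the cast untyped storage with set_, while the compile path splits the cast flat storage into per-leaf byte slices and reinterprets each slice with the destination tensor's dtype and shape, which is friendlier to torch.compile. A self-contained toy sketch of that split-and-view idea (buffer layout and keys are made up for illustration):

import torch

# One flat byte buffer standing in for the consolidated storage after the cast,
# holding a float32 (2, 2) tensor followed by an int16 (4,) tensor.
a = torch.arange(4, dtype=torch.float32).reshape(2, 2)
b = torch.arange(4, dtype=torch.int16)
flat = torch.cat([a.view(torch.uint8).reshape(-1), b.view(torch.uint8).reshape(-1)])

# Per-leaf byte lengths in storage order, as the metadata would provide them.
splits = [a.numel() * a.element_size(), b.numel() * b.element_size()]
slice_map = dict(zip([("a",), ("b",)], flat.split(splits)))

def view_as(src, dest):
    # Reinterpret the byte slice with dest's dtype, drop any padding, reshape.
    return src.view(dest.dtype)[: dest.numel()].view(dest.shape)

rebuilt_a = view_as(slice_map[("a",)], a)
rebuilt_b = view_as(slice_map[("b",)], b)
assert torch.equal(rebuilt_a, a) and torch.equal(rebuilt_b, b)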
