
Commit c730f7a

[Performance] Make _to_consolidated compatible with compile

ghstack-source-id: c0f6116
Pull Request resolved: #1041

1 parent: fe6db77
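
Note (hedged usage sketch, not part of the diff): the entry points below (TensorDict.consolidate(), .to()) are the public tensordict API this change appears to serve, and the device choice is illustrative only. The new compilable flag is selected automatically when the call runs under torch.compile.

import torch
from tensordict import TensorDict

# Hedged sketch: consolidate() packs all leaves into one flat storage; moving the
# consolidated tensordict inside a compiled function is the path this commit targets.
td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3)}}, batch_size=[]).consolidate()

device = "cuda" if torch.cuda.is_available() else "cpu"

@torch.compile
def move(td):
    # Under tracing, to() should route through the new compile-friendly branch.
    return td.to(device)

move(td)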

1 file changed: +141 −17 lines
tensordict/base.py (+141 −17)
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
 
         flat_size = []
         start = 0
+        sorting_index = 0
 
         def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
-            nonlocal start
+            nonlocal start, sorting_index
             n = value.element_size() * value.numel()
             if need_padding:
                 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
                 start,
                 stop,
                 pad,
+                flat_size[-1],
+                sorting_index,
             )
+            sorting_index = sorting_index + 1
            start = stop
 
         def assign(
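
Note (hedged reading of the hunks above): each leaf's metadata entry now additionally stores its flattened byte length (flat_size[-1]) and an insertion-order index (sorting_index). The compile path added further down reads these as the last two fields (v[-2], v[-1]) to recover, in storage order, the sizes needed to split the consolidated buffer. A minimal illustration with invented keys and numbers:

# Invented entries; only the trailing (flat_length, sorting_index) layout mirrors the diff.
leaves = {
    "a": (0, 24, 0, 24, 0),    # ..., flat_length, sorting_index
    "b": (24, 56, 0, 32, 1),
}
keys = [None] * len(leaves)
sizes = [None] * len(leaves)
for k, v in leaves.items():
    keys[v[-1]] = k
    sizes[v[-1]] = v[-2]
# keys/sizes are now ordered by sorting_index, ready for consolidated_storage.split(sizes)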
@@ -10395,6 +10399,7 @@ def to(self, *args, **kwargs) -> T:
                 pin_memory=non_blocking_pin,
                 num_threads=num_threads,
                 non_blocking=non_blocking,
+                compilable=is_dynamo_compiling(),
             )
 
         if non_blocking is None:
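
Note (hedged): compilable=is_dynamo_compiling() is evaluated at call time, so the same to() call takes the eager branch in normal execution and the compile-friendly branch only while Dynamo is tracing. The exact import location of the flag is an assumption; recent PyTorch exposes it as shown below (older releases use torch._dynamo.is_compiling):

import torch

def which_branch():
    # True only while TorchDynamo is tracing the calling code.
    return "compile" if torch.compiler.is_dynamo_compiling() else "eager"

print(which_branch())  # "eager" when called outside torch.compile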
@@ -10452,14 +10457,42 @@ def to_pinmem(tensor, _to=to):
             self._sync_all()
         return result
 
-    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
+    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking, compilable):
         if num_threads is None:
             # unspecified num_threads should mean 0
             num_threads = 0
+
         storage = self._consolidated["storage"]
-        if pin_memory:
-            storage = storage.pin_memory()
-        storage_cast = storage.to(device, non_blocking=True)
+
+        @torch.compiler.disable()
+        def to(storage):
+            if pin_memory:
+                storage = storage.pin_memory()
+            storage_cast = storage.to(device, non_blocking=True)
+            return storage_cast
+
+        storage_cast = to(storage)
+
+        if compilable:
+            result = self._to_consolidated_compile(device=device, num_threads=num_threads, storage_cast=storage_cast)
+        else:
+            result = self._to_consolidated_eager(device=device, num_threads=num_threads, storage_cast=storage_cast)
+
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
+    def _to_consolidated_eager(self, *, device, num_threads, storage_cast):
+
         untyped_storage = storage_cast.untyped_storage()
 
         def set_(x):
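
Note on the structure above: the raw byte transfer (optional pin_memory plus storage.to) is wrapped in @torch.compiler.disable(), so the copy itself always runs eagerly while the surrounding dispatch can be traced. A minimal sketch of that pattern with made-up names (not tensordict code):

import torch

@torch.compiler.disable()
def eager_move(x, device):
    # Skipped by Dynamo: always runs eagerly, even when the caller is compiled.
    return x.to(device, non_blocking=True)

@torch.compile
def step(x):
    y = eager_move(x, "cpu")  # a graph break surrounds the call, then tracing resumes
    return y + 1

print(step(torch.randn(4)))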
@@ -10528,20 +10561,111 @@ def copy_dict(d):
                 }
 
             result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
-        if non_blocking in (False, None):
-            if device.type == "cuda" and non_blocking is False:
-                # sending to CUDA force sync
-                cuda_device = device
-            elif storage.device.type == "cuda":
-                # sending from cuda: need sync unless intentionally not asked for
-                cuda_device = storage.device.type
-            else:
-                cuda_device = None
-            if cuda_device is not None:
-                torch.cuda.current_stream(cuda_device).synchronize()
-
         return result
 
+    def _to_consolidated_compile(self, *, device, num_threads, storage_cast):
+
+        def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
+            root = False
+            if lengths is None:
+                lengths = []
+                pos = []
+                keys = []
+                root = True
+            for k, v in metadata["leaves"].items():
+                lengths.append(v[-2])
+                pos.append(v[-1])
+                keys.append(prefix + (k,))
+            for k, d in metadata.items():
+                if "leaves" in d:
+                    get_tensors_length(d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,))
+            if root:
+                # l = torch.empty(len(lengths), dtype=torch.long)
+                # l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
+                out0 = [None, ] * len(pos)
+                out1 = [None, ] * len(pos)
+                for p, l, k in zip(pos, lengths, keys):
+                    out0[p] = k
+                    out1[p] = l
+                return out0, out1
+
+        def split_storage(consolidated):
+            keys, splits = get_tensors_length(consolidated["metadata"])
+            return dict(zip(keys, consolidated["storage"].split(splits)))
+
+        if num_threads is None:
+            # unspecified num_threads should mean 0
+            num_threads = 0
+
+        _consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            # faster than deepcopy
+            def copy_dict(d):
+                return {
+                    k: v if not isinstance(v, dict) else copy_dict(v)
+                    for k, v in d.items()
+                }
+
+            _consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
+
+        slice_map = split_storage(_consolidated)
+
+        def set_(name, x):
+            if not isinstance(name, tuple):
+                name = (name,)
+            if x.is_nested:
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                storage_offsets = slice_map[(*name[:-1], "<NJT_OFFSETS>" + name[-1],)]
+                kwargs["offsets"] = storage_offsets.view(offsets.dtype).view(offsets.shape)
+                if lengths is not None:
+                    storage_lengths = slice_map[(*name[:-1], "<NJT_LENGTHS>" + name[-1],)]
+                    kwargs["lengths"] = storage_lengths.view(lengths.dtype).view(lengths.shape)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                storage_values = slice_map[(*name[:-1], "<NJT_VALUES>" + name[-1],)]
+                return NestedTensor(
+                    storage_values.view(values.dtype).view(values.shape),
+                    **kwargs,
+                )
+            return slice_map[name].view(x.dtype).view(x.shape)
+
+        result = self._fast_apply(
+            set_, device=torch.device(device), num_threads=num_threads, named=True, nested_keys=True,
+        )
+        result._consolidated = _consolidated
+        return result
     def _sync_all(self):
         if _has_cuda:
             # TODO: dynamo doesn't like torch.cuda.is_initialized
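
Note (hedged sketch): instead of the eager path's untyped_storage()/set_ trick, the compile path splits the consolidated byte buffer per leaf and reinterprets each slice with view(dtype).view(shape); the padding logic in add_single_value (see pad = n % 8 above) presumably keeps every slice aligned for the widest dtypes. The idea in isolation, with invented tensors:

import torch

a = torch.arange(6, dtype=torch.float32)
b = torch.ones(2, 2, dtype=torch.int64)

# Pack both tensors into one flat uint8 buffer (24 + 32 bytes).
flat = torch.cat([a.view(torch.uint8).flatten(), b.view(torch.uint8).flatten()])

sizes = [a.numel() * a.element_size(), b.numel() * b.element_size()]
slice_a, slice_b = flat.split(sizes)

# Reinterpret each slice with the original dtype and shape.
a2 = slice_a.view(torch.float32).view(a.shape)
b2 = slice_b.view(torch.int64).view(b.shape)
assert torch.equal(a2, a) and torch.equal(b2, b)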
