Commit 468b4be

Update

[ghstack-poisoned]

vmoens committed Feb 12, 2025
1 parent 89d7127 commit 468b4be
Showing 6 changed files with 113 additions and 26 deletions.
70 changes: 53 additions & 17 deletions tensordict/_lazy.py
@@ -63,6 +63,7 @@
_get_shape_from_args,
_getitem_batch_size,
_is_number,
_maybe_correct_neg_dim,
_parse_to,
_renamed_inplace_method,
_shape,
@@ -292,6 +293,49 @@ def __init__(
if stack_dim_name is not None:
self._td_dim_name = stack_dim_name

@classmethod
def _new_lazy_unsafe(
cls,
*tensordicts: T,
stack_dim: int = 0,
hook_out: callable | None = None,
hook_in: callable | None = None,
batch_size: Sequence[int] | None = None,
device: torch.device | None = None,
names: Sequence[str] | None = None,
stack_dim_name: str | None = None,
strict_shape: bool = False,
) -> None:
self = cls.__new__(cls)
self._is_locked = None

# sanity check
num_tds = len(tensordicts)
batch_size = torch.Size(batch_size) if batch_size is not None else None
if not num_tds:
# create an empty tensor
td0 = TensorDict(batch_size=batch_size, device=device, names=names)
self._device = torch.device(device) if device is not None else None
else:
td0 = tensordicts[0]
# device = td0.device
_batch_size = td0.batch_size

for td in tensordicts[1:]:
_bs = td.batch_size
if _bs != _batch_size:
_batch_size = torch.Size(
[s if _bs[i] == s else -1 for i, s in enumerate(_batch_size)]
)
self.tensordicts: list[TensorDictBase] = list(tensordicts)
self.stack_dim = stack_dim
self._batch_size = self._compute_batch_size(_batch_size, stack_dim, num_tds)
self.hook_out = hook_out
self.hook_in = hook_in
if stack_dim_name is not None:
self._td_dim_name = stack_dim_name
return self

# These attributes should never be set
@property
def _is_shared(self):
@@ -633,7 +677,9 @@ def _split_index(self, index):
encountered_tensor = False
for i, idx in enumerate(index): # noqa: B007
cursor_incr = 1
if idx is None:
# if idx is None:
# idx = True
if idx is None or idx is True:
out.append(None)
num_none += cursor <= self.stack_dim
continue
@@ -1675,6 +1721,8 @@ def _iterate_over_keys(self) -> None:

@cache # noqa: B019
def _key_list(self):
if not self.tensordicts:
return []
keys = set(self.tensordicts[0].keys())
for td in self.tensordicts[1:]:
keys = keys.intersection(td.keys())
@@ -2099,15 +2147,6 @@ def assign(converted_idx, value=value):
value_unbind,
):
if mask.any():
assert (
self.tensordicts[i][_idx].shape
== torch.zeros(self.tensordicts[i].shape)[_idx].shape
), (
self.tensordicts[i].shape,
_idx,
self.tensordicts[i][_idx],
torch.zeros(self.tensordicts[i].shape)[_idx].shape,
)
self.tensordicts[i][_idx] = _value
else:
for (i, _idx), _value in _zip_strict(
@@ -2160,7 +2199,7 @@ def __getitem__(self, index: IndexType) -> Any:
batch_size = _getitem_batch_size(self.batch_size, index)
else:
batch_size = None
return LazyStackedTensorDict(
return self._new_lazy_unsafe(
*result,
stack_dim=cat_dim,
device=self.device,
@@ -3203,10 +3242,7 @@ def _unsqueeze(self, dim):
return result

def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
if dim < 0:
dim = self.ndim + dim
if dim < 0 or dim > self.ndim - 1:
raise ValueError(f"Out-of-bounds dim value: {dim}.")
dim = _maybe_correct_neg_dim(dim, shape=self.shape)
if dim == self.stack_dim:
if isinstance(split_size, int):
split_size = [split_size] * -(len(self.tensordicts) // -split_size)
@@ -3217,15 +3253,15 @@ def iter_across_tds():
for s in split_size:
if s == 0:
batch_size = list(self._batch_size)
batch_size.pop(self.stack_dim)
batch_size[self.stack_dim] = 0
yield LazyStackedTensorDict(
batch_size=batch_size,
device=self.device,
stack_dim=self.stack_dim,
)
continue
stop = start + s
yield LazyStackedTensorDict(
yield self._new_lazy_unsafe(
*self.tensordicts[slice(start, stop)], stack_dim=self.stack_dim
)
start = stop
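A minimal usage sketch (not part of the commit) of the split behaviour these _lazy.py hunks touch: the dim argument is now normalized through _maybe_correct_neg_dim, zero-length chunks keep the stack dimension with size 0 rather than dropping it, and same-stack-dim chunks go through the lighter _new_lazy_unsafe path. The tensordicts, shapes and the list() wrapping below are illustrative assumptions, not code from the repository.

import torch
from tensordict import TensorDict, LazyStackedTensorDict

# Three tensordicts stacked along dim 0 -> batch_size == (3, 4).
tds = [TensorDict({"a": torch.zeros(4)}, batch_size=[4]) for _ in range(3)]
stacked = LazyStackedTensorDict(*tds, stack_dim=0)

# Negative dims are normalized before dispatch; dim=-2 is the stack dim here.
chunks = list(stacked.split(2, dim=-2))
assert [c.batch_size[0] for c in chunks] == [2, 1]

# A zero-length chunk is expected to keep the stack dim with size 0,
# matching torch.split semantics for a (3, 4)-shaped batch.
chunks = list(stacked.split([0, 3], dim=0))
assert chunks[0].batch_size == torch.Size([0, 4])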
16 changes: 16 additions & 0 deletions tensordict/_td.py
@@ -2518,7 +2518,23 @@ def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
tensor_in = self._get_str(key, NO_DEFAULT)

if is_non_tensor(value) and not (self._is_shared or self._is_memmap):
if isinstance(idx, tuple) and len(idx) == 1:
idx = idx[0]
dest = tensor_in
if (
isinstance(idx, torch.Tensor)
and idx.shape == ()
and self.shape == ()
and idx.dtype == torch.bool
and idx
):
self._set_str(
key,
dest.squeeze(0),
validated=True,
inplace=False,
ignore_lock=True,
)
is_diff = dest[idx].tolist() != value.tolist()
if is_diff:
dest_val = dest.maybe_to_stack()
2 changes: 1 addition & 1 deletion tensordict/base.py
@@ -6593,7 +6593,7 @@ def update(
value = tree_map(torch.clone, value)
# the key must be a string by now. Let's check if it is present
if target is not None:
if not is_leaf(type(target)):
if not is_leaf(type(target)) and not is_leaf(type(value)):
if subkey:
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
25 changes: 20 additions & 5 deletions tensordict/tensorclass.py
@@ -1571,7 +1571,6 @@ def wrapped_func(self, *args, **kwargs):
td = super(type(self), self).__getattribute__("_tensordict")
else:
td = self._tensordict

result = getattr(td, funcname)(*args, **kwargs)
if no_wrap:
return result
@@ -1776,10 +1775,26 @@ def _setitem(self, item: NestedKey, value: Any) -> None:  # noqa: D417
value (any): value to set for the item
"""
if isinstance(item, str) or (
isinstance(item, tuple) and all(isinstance(_item, str) for _item in item)
):
raise ValueError(f"Invalid indexing arguments: {item}.")
istuple = isinstance(item, tuple)
if istuple or isinstance(item, str):
# _unravel_key_to_tuple will return an empty tuple if the index isn't a NestedKey
idx_unravel = _unravel_key_to_tuple(item)
if idx_unravel:
raise ValueError(f"Invalid indexing arguments: {item}.")

if istuple and len(item) == 1:
return _setitem(self, item[0], value)
if (
(
isinstance(item, torch.Tensor)
and item.dtype == torch.bool
and not item.shape
and item
)
or (item is True)
or (item is None)
) and self.batch_size == ():
return self.update(value.squeeze(0))

if not is_tensorclass(value) and not isinstance(
value, (TensorDictBase, numbers.Number, Tensor)
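A hedged illustration (not taken from the commit) of the __setitem__ edge case this tensorclass.py hunk adds: on a tensorclass with an empty batch size, indexing with True, None, or a 0-d boolean tensor updates the object from the value with its leading singleton dimension squeezed out. The Data class and shapes below are hypothetical.

import torch
from tensordict import tensorclass

@tensorclass
class Data:
    x: torch.Tensor

d = Data(x=torch.zeros(()), batch_size=[])
# `True` adds a leading size-1 dim, so the assigned value carries batch_size (1,).
d[True] = Data(x=torch.ones(1), batch_size=[1])
assert (d.x == 1).all()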
14 changes: 11 additions & 3 deletions tensordict/utils.py
@@ -1243,7 +1243,11 @@ def func_as_decorator(_self, *args, **kwargs):
if out is not None:
if _attr_post is not _attr_pre:
ref = weakref.ref(_self)
out._last_op = (
if is_tensorclass(out):
out_lo = out._tensordict
else:
out_lo = out
out_lo._last_op = (
func.__name__,
(
args,
@@ -1262,7 +1266,11 @@ def func_as_decorator(_self, *args, **kwargs):
out = func(_self, *args, **kwargs)
if out is not None:
ref = weakref.ref(_self)
out._last_op = (func.__name__, (args, kwargs, ref))
if is_tensorclass(out):
out_lo = out._tensordict
else:
out_lo = out
out_lo._last_op = (func.__name__, (args, kwargs, ref))
return out

return func_as_decorator
@@ -2023,7 +2031,7 @@ def _getitem_batch_size(batch_size, index):
out = []
count = -1
for i, idx in enumerate(index):
if idx is None:
if idx is True or idx is None:
out.append(1)
continue
count += 1 if not bools[i] else idx.ndim
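The _getitem_batch_size tweak mirrors standard PyTorch indexing, where a literal True in an index behaves like None and inserts a size-1 dimension. A quick plain-torch sketch of that rule:

import torch

x = torch.arange(6).reshape(2, 3)
assert x[None].shape == (1, 2, 3)     # None inserts a size-1 dim
assert x[True].shape == (1, 2, 3)     # a literal True does the same
assert x[:, True].shape == (2, 1, 3)  # at any position in the index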
12 changes: 12 additions & 0 deletions test/test_tensordict.py
@@ -11299,6 +11299,18 @@ def test_set(self, non_tensor_data):
== "another string!"
)

def test_setitem_edge_case(self):
s = NonTensorStack("a string")
t = NonTensorStack("another string")
s[0][True] = t
assert s[0].data == "another string"
for i in (None, True):
s = NonTensorStack("0", "1")
t = NonTensorStack(NonTensorStack("2", "3"), stack_dim=1)
assert t.batch_size == (2, 1)
s[:, i] = t
assert s.tolist() == ["2", "3"]

def test_stack(self, non_tensor_data):
assert (
LazyStackedTensorDict.lazy_stack([non_tensor_data, non_tensor_data], 0).get(
