|
20 | 20 | import torch
|
21 | 21 |
|
22 | 22 | from tensordict import TensorDict, unravel_key
|
| 23 | +from tensordict._tensordict import _unravel_key_to_tuple |
23 | 24 | from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
|
24 | 25 | from torch import multiprocessing as mp
|
25 | 26 | from torchrl._utils import _check_for_faulty_process, VERBOSE
|
@@ -761,16 +762,18 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
|
761 | 762 | if self._single_task:
|
762 | 763 | # this is faster than update_ but won't work for lazy stacks
|
763 | 764 | for key in self.env_input_keys:
|
764 |
| - self.shared_tensordict_parent.set( |
765 |
| - key, |
766 |
| - tensordict.get(key), |
767 |
| - inplace=True, |
768 |
| - ) |
769 |
| - # self.shared_tensordict_parent._set_tuple( |
| 765 | + # self.shared_tensordict_parent.set( |
770 | 766 | # key,
|
771 |
| - # tensordict._get_tuple(key, None), |
| 767 | + # tensordict.get(key), |
772 | 768 | # inplace=True,
|
773 | 769 | # )
|
| 770 | + key = _unravel_key_to_tuple(key) |
| 771 | + self.shared_tensordict_parent._set_tuple( |
| 772 | + key, |
| 773 | + tensordict._get_tuple(key, None), |
| 774 | + inplace=True, |
| 775 | + validated=True, |
| 776 | + ) |
774 | 777 | else:
|
775 | 778 | self.shared_tensordict_parent.update_(
|
776 | 779 | tensordict.select(*self.env_input_keys, strict=False)
|
@@ -1062,8 +1065,14 @@ def _run_worker_pipe_shared_mem(
|
1062 | 1065 | i += 1
|
1063 | 1066 | if local_tensordict is not None:
|
1064 | 1067 | for key in env_input_keys:
|
1065 |
| - local_tensordict.set(key, shared_tensordict.get(key)) |
1066 |
| - # local_tensordict._set_tuple(key, shared_tensordict._get_tuple(key, None), inplace=False, validated=True) |
| 1068 | + # local_tensordict.set(key, shared_tensordict.get(key)) |
| 1069 | + key = _unravel_key_to_tuple(key) |
| 1070 | + local_tensordict._set_tuple( |
| 1071 | + key, |
| 1072 | + shared_tensordict._get_tuple(key, None), |
| 1073 | + inplace=False, |
| 1074 | + validated=True, |
| 1075 | + ) |
1067 | 1076 | else:
|
1068 | 1077 | local_tensordict = shared_tensordict.clone(recurse=False)
|
1069 | 1078 | local_tensordict = env._step(local_tensordict)
|
|
0 commit comments