Skip to content

Commit 7c5ba3d

Browse files
authored
[Refactor] Use _set_tuple for faster set (pytorch#1372)
1 parent 8d913b8 commit 7c5ba3d

File tree

3 files changed

+34
-22
lines changed

3 files changed

+34
-22
lines changed

torchrl/envs/utils.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -231,15 +231,15 @@ def _set_single_key(source, dest, key, clone=False):
231231
new_val = dest.get(k, None)
232232
if new_val is None:
233233
new_val = _clone_no_keys(val)
234-
dest.set(k, new_val)
235-
# dest._set_tuple(k, new_val, inplace=False, validated=True)
234+
# dest.set(k, new_val)
235+
dest._set_str(k, new_val, inplace=False, validated=True)
236236
source = val
237237
dest = new_val
238238
else:
239239
if clone:
240240
val = val.clone()
241-
dest.set(k, val)
242-
# dest._set_tuple(k, val, inplace=False, validated=True)
241+
# dest.set(k, val)
242+
dest._set_str(k, val, inplace=False, validated=True)
243243

244244

245245
def _set(source, dest, key, total_key, excluded):
@@ -257,13 +257,13 @@ def _set(source, dest, key, total_key, excluded):
257257
_set(val, new_val, subkey, total_key, excluded) or non_empty_local
258258
)
259259
if non_empty_local:
260-
dest.set(key, new_val)
261-
# dest._set_tuple(key, new_val, inplace=False, validated=True)
260+
# dest.set(key, new_val)
261+
dest._set_str(key, new_val, inplace=False, validated=True)
262262
non_empty = non_empty_local
263263
else:
264264
non_empty = True
265-
dest.set(key, val)
266-
# dest._set_tuple(key, val, inplace=False, validated=True)
265+
# dest.set(key, val)
266+
dest._set_str(key, val, inplace=False, validated=True)
267267
return non_empty
268268

269269

torchrl/envs/vec_env.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121

2222
from tensordict import TensorDict, unravel_key
23+
from tensordict._tensordict import _unravel_key_to_tuple
2324
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
2425
from torch import multiprocessing as mp
2526
from torchrl._utils import _check_for_faulty_process, VERBOSE
@@ -761,16 +762,18 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
761762
if self._single_task:
762763
# this is faster than update_ but won't work for lazy stacks
763764
for key in self.env_input_keys:
764-
self.shared_tensordict_parent.set(
765-
key,
766-
tensordict.get(key),
767-
inplace=True,
768-
)
769-
# self.shared_tensordict_parent._set_tuple(
765+
# self.shared_tensordict_parent.set(
770766
# key,
771-
# tensordict._get_tuple(key, None),
767+
# tensordict.get(key),
772768
# inplace=True,
773769
# )
770+
key = _unravel_key_to_tuple(key)
771+
self.shared_tensordict_parent._set_tuple(
772+
key,
773+
tensordict._get_tuple(key, None),
774+
inplace=True,
775+
validated=True,
776+
)
774777
else:
775778
self.shared_tensordict_parent.update_(
776779
tensordict.select(*self.env_input_keys, strict=False)
@@ -1062,8 +1065,14 @@ def _run_worker_pipe_shared_mem(
10621065
i += 1
10631066
if local_tensordict is not None:
10641067
for key in env_input_keys:
1065-
local_tensordict.set(key, shared_tensordict.get(key))
1066-
# local_tensordict._set_tuple(key, shared_tensordict._get_tuple(key, None), inplace=False, validated=True)
1068+
# local_tensordict.set(key, shared_tensordict.get(key))
1069+
key = _unravel_key_to_tuple(key)
1070+
local_tensordict._set_tuple(
1071+
key,
1072+
shared_tensordict._get_tuple(key, None),
1073+
inplace=False,
1074+
validated=True,
1075+
)
10671076
else:
10681077
local_tensordict = shared_tensordict.clone(recurse=False)
10691078
local_tensordict = env._step(local_tensordict)

torchrl/objectives/common.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,10 @@ def _param_getter(self, network_name):
412412
value_to_set = getattr(self, value_to_set).detach()
413413
else:
414414
value_to_set = getattr(self, value_to_set)
415-
params.set(key, value_to_set)
416-
# params._set_tuple(key, value_to_set, inplace=False, validated=True)
415+
# params.set(key, value_to_set)
416+
params._set_tuple(
417+
key, value_to_set, inplace=False, validated=True
418+
)
417419
return params
418420
else:
419421
params = getattr(self, param_name)
@@ -448,9 +450,10 @@ def _target_param_getter(self, network_name):
448450
value_to_set = getattr(
449451
self, self.SEP.join(["_target_" + network_name, *key])
450452
)
451-
# _set is faster bc is bypasses the checks
452-
target_params.set(key, value_to_set)
453-
# target_params._set_tuple(key, value_to_set, inplace=False, validated=True)
453+
# target_params.set(key, value_to_set)
454+
target_params._set_tuple(
455+
key, value_to_set, inplace=False, validated=True
456+
)
454457
return target_params
455458
else:
456459
params = getattr(self, param_name)

0 commit comments

Comments
 (0)