Commit d02f5a1

Authored Mar 8, 2023
[Refactor] Refactor the step to include reward and done in the 'next' tensordict (pytorch#941)
1 parent cdc6798 commit d02f5a1
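In short, after this refactor the outcome of a step (next observation, reward and done flag) is written under the nested `"next"` entry of the tensordict rather than at its root. A minimal sketch of the resulting layout (the Gym backend and `"Pendulum-v1"` are illustrative choices only):

```python
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")
tensordict = env.reset()                       # root holds the initial observation
tensordict["action"] = env.action_spec.rand()  # pick a random action
tensordict = env.step(tensordict)

# after this commit, the step outcome lives in the nested "next" tensordict
next_obs = tensordict["next", "observation"]
reward = tensordict["next", "reward"]
done = tensordict["next", "done"]
```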


56 files changed: +1756, -1444 lines
 

‎README.md

+183 additions, -171 deletions
@@ -37,6 +37,13 @@ It contains tutorials and the API reference.
TorchRL relies on [`TensorDict`](https://github.com/pytorch-labs/tensordict/),
a convenient data structure<sup>(1)</sup> to pass data from
one object to another without friction.

Here is an example of how the [environment API](https://pytorch.org/rl/reference/envs.html)
relies on tensordict to carry data from one function to another during a rollout
execution:

![Rollout with TorchRL](docs/source/_static/img/rollout.gif)

`TensorDict` makes it easy to re-use pieces of code across environments, models and
algorithms. For instance, here's how to code a rollout in TorchRL:
<details>
@@ -156,202 +163,207 @@ The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py)
## Features

- A common [interface for environments](torchrl/envs)
  which supports common libraries (OpenAI Gym, DeepMind control suite, etc.)<sup>(1)</sup> and stateless execution
  (e.g. model-based environments).
  The [batched environment](torchrl/envs/vec_env.py) containers allow parallel execution<sup>(2)</sup>.
  A common PyTorch-first [tensor-specification class](torchrl/data/tensor_specs.py) is also provided.
  TorchRL's environment API is simple but stringent and specific. Check the
  [documentation](https://pytorch.org/rl/reference/envs.html)
  and [tutorial](https://pytorch.org/rl/tutorials/pendulum.html) to learn more!
  <details>
    <summary>Code</summary>

  ```python
  env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
  env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
  tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
  assert tensordict.shape == [4, 20]  # 4 envs, 20-step rollout
  env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
  ```
  </details>

- multiprocess [data collectors](torchrl/collectors/collectors.py)<sup>(2)</sup> that work synchronously or asynchronously.
  Thanks to TensorDict, TorchRL's training loops look very much like regular training loops in supervised
  learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
  <details>
    <summary>Code</summary>

  ```python
  env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
  collector = MultiaSyncDataCollector(
      [env_make, env_make],
      policy=policy,
      devices=["cuda:0", "cuda:0"],
      total_frames=10000,
      frames_per_batch=50,
      ...
  )
  for i, tensordict_data in enumerate(collector):
      loss = loss_module(tensordict_data)
      loss.backward()
      optim.step()
      optim.zero_grad()
      collector.update_policy_weights_()
  ```
  </details>

- efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
  <details>
    <summary>Code</summary>

  ```python
  storage = LazyMemmapStorage(  # memory-mapped (physical) storage
      cfg.buffer_size,
      scratch_dir="/tmp/"
  )
  buffer = TensorDictPrioritizedReplayBuffer(
      alpha=0.7,
      beta=0.5,
      collate_fn=lambda x: x,
      pin_memory=device != torch.device("cpu"),
      prefetch=10,  # multi-threaded sampling
      storage=storage
  )
  ```
  </details>

- cross-library [environment transforms](torchrl/envs/transforms/transforms.py)<sup>(1)</sup>,
  executed on device and in a vectorized fashion<sup>(2)</sup>,
  which process and prepare the data coming out of the environments to be used by the agent:
  <details>
    <summary>Code</summary>

  ```python
  env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
  env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
  env = TransformedEnv(
      env_base,
      Compose(
          ToTensorImage(),
          ObservationNorm(loc=0.5, scale=1.0)),  # executes the transforms once and on device
  )
  tensordict = env.reset()
  assert tensordict.device == torch.device("cuda:0")
  ```
  Other transforms include: reward scaling (`RewardScaling`), shape operations (concatenation of tensors, unsqueezing, etc.), concatenation of
  successive frames (`CatFrames`), resizing (`Resize`) and many more.

  Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it
  easy to add and remove them at will:
  ```python
  env.insert_transform(0, NoopResetEnv())  # inserts the NoopResetEnv transform at index 0
  ```
  Nevertheless, transforms can access and execute operations on the parent environment:
  ```python
  transform = env.transform[1]  # gathers the second transform of the list
  parent_env = transform.parent  # returns the base environment of the second transform, i.e. the base env + the first transform
  ```
  </details>

- various tools for distributed learning (e.g. [memory mapped tensors](https://github.com/pytorch-labs/tensordict/blob/main/tensordict/memmap.py))<sup>(2)</sup>;
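  a minimal sketch (with made-up data) follows, showing how a tensordict can be moved to memory-mapped
  storage in place, so that its contents live on disk and can be shared across processes at little cost:
  <details>
    <summary>Code</summary>

  ```python
  import torch
  from tensordict import TensorDict

  td = TensorDict(
      {"observation": torch.randn(1000, 4), "reward": torch.randn(1000, 1)},
      batch_size=[1000],
  )
  td.memmap_()           # replaces each tensor with a memory-mapped counterpart
  assert td.is_memmap()  # the data is now backed by files rather than RAM
  ```
  </details>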
- various [architectures](torchrl/modules/models/) and models (e.g. [actor-critic](torchrl/modules/tensordict_module/actors.py))<sup>(1)</sup>:
  <details>
    <summary>Code</summary>

  ```python
  # create an nn.Module
  common_module = ConvNet(
      bias_last_layer=True,
      depth=None,
      num_cells=[32, 64, 64],
      kernel_sizes=[8, 4, 3],
      strides=[4, 2, 1],
  )
  # Wrap it in a SafeModule, indicating what key to read in and where to
  # write out the output
  common_module = SafeModule(
      common_module,
      in_keys=["pixels"],
      out_keys=["hidden"],
  )
  # Wrap the policy module in NormalParamsWrapper, such that the output
  # tensor is split into loc and scale, and scale is mapped onto a positive space
  policy_module = SafeModule(
      NormalParamsWrapper(
          MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
      ),
      in_keys=["hidden"],
      out_keys=["loc", "scale"],
  )
  # Use a SafeProbabilisticSequential to combine the SafeModule with a
  # SafeProbabilisticModule, indicating how to build the
  # torch.distributions.Distribution object and what to do with it
  policy_module = SafeProbabilisticSequential(  # stochastic policy
      policy_module,
      SafeProbabilisticModule(
          in_keys=["loc", "scale"],
          out_keys=["action"],
          distribution_class=TanhNormal,
      ),
  )
  value_module = MLP(
      num_cells=[64, 64],
      out_features=1,
      activation=nn.ELU,
  )
  # Wrap the policy and value function in a common module
  actor_value = ActorValueOperator(common_module, policy_module, value_module)
  # standalone policy from this
  standalone_policy = actor_value.get_policy_operator()
  ```
  </details>

- exploration [wrappers](torchrl/modules/tensordict_module/exploration.py) and
  [modules](torchrl/modules/models/exploration.py) to easily swap between exploration and exploitation<sup>(1)</sup>:
  <details>
    <summary>Code</summary>

  ```python
  policy_explore = EGreedyWrapper(policy)
  with set_exploration_mode("random"):
      tensordict = policy_explore(tensordict)  # will use eps-greedy
  with set_exploration_mode("mode"):
      tensordict = policy_explore(tensordict)  # will not use eps-greedy
  ```
  </details>

- A series of efficient [loss modules](https://github.com/pytorch/rl/blob/main/torchrl/objectives/costs)
  and highly vectorized
  [functional return and advantage](https://github.com/pytorch/rl/blob/main/torchrl/objectives/returns/functional.py)
  computation.

  <details>
    <summary>Code</summary>

  ### Loss modules
  ```python
  from torchrl.objectives import DQNLoss
  loss_module = DQNLoss(value_network=value_network, gamma=0.99)
  tensordict = replay_buffer.sample(batch_size)
  loss = loss_module(tensordict)
  ```

  ### Advantage computation
  ```python
  from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
  advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done)
  ```
  </details>
- a generic [trainer class](torchrl/trainers/trainers.py)<sup>(1)</sup> that
  executes the aforementioned training loop. Through a hooking mechanism,
  it also supports any logging or data transformation operation at any given
  time.
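  As a rough sketch of how these pieces fit together (constructor arguments and hook
  names are indicative and may differ across versions; `collector`, `loss_module` and
  `optim` are assumed to come from the snippets above):
  <details>
    <summary>Code</summary>

  ```python
  from torchrl.trainers import Trainer

  trainer = Trainer(
      collector=collector,        # data collector, as built above
      total_frames=10_000,
      loss_module=loss_module,    # e.g. the DQNLoss above
      optimizer=optim,
      optim_steps_per_batch=8,    # optimization steps per collected batch
  )
  # hooking mechanism: run arbitrary code at predefined points of the loop,
  # e.g. flatten each collected batch before it reaches the loss
  trainer.register_op("batch_process", lambda batch: batch.reshape(-1))
  trainer.train()
  ```
  </details>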
- various [recipes](torchrl/trainers/helpers/models.py) to build models that
  correspond to the environment being deployed.

If you feel a feature is missing from the library, please submit an issue!
If you would like to contribute to new features, check our [call for contributions](https://github.com/pytorch/rl/issues/509) and our [contribution](CONTRIBUTING.md) page.

## Examples, tutorials and demos

A series of [examples](examples/) are provided for illustrative purposes:

‎docs/source/_static/img/rollout.gif

840 KB
