Skip to content

Commit b210665

Browse files
[Feature] Heterogeneous Environments compatibility (pytorch#1411)
Signed-off-by: Matteo Bettini <[email protected]> Co-authored-by: vmoens <[email protected]>
1 parent 83dfff3 commit b210665

File tree

6 files changed

+459
-37
lines changed

6 files changed

+459
-37
lines changed

test/mocking_classes.py

+194-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import torch
88
import torch.nn as nn
99
from tensordict.tensordict import TensorDict, TensorDictBase
10-
from tensordict.utils import NestedKey
10+
from tensordict.utils import expand_right, NestedKey
1111

1212
from torchrl.data.tensor_specs import (
1313
BinaryDiscreteTensorSpec,
@@ -19,6 +19,7 @@
1919
TensorSpec,
2020
UnboundedContinuousTensorSpec,
2121
)
22+
from torchrl.data.utils import consolidate_spec
2223
from torchrl.envs.common import EnvBase
2324
from torchrl.envs.model_based.common import ModelBasedEnvBase
2425

@@ -1290,3 +1291,195 @@ def _step(
12901291
device=self.device,
12911292
)
12921293
return tensordict.select().set("next", tensordict)
1294+
1295+
1296+
class HeteroCountingEnvPolicy:
    """Deterministic policy companion for ``HeteroCountingEnv``.

    Produces a zeroed action tensordict from ``full_action_spec``.  When
    ``count`` is True, every action leaf is bumped to 1 instead, which
    (per the env's ``_step``) makes the env advance its counter.
    """

    def __init__(self, full_action_spec: TensorSpec, count: bool = True):
        self.full_action_spec = full_action_spec
        self.count = count

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        # Start from an all-zero action sampled off the spec.
        action = self.full_action_spec.zero()
        if self.count:
            # In-place bump of every leaf: zeros -> ones.
            action.apply_(lambda leaf: leaf + 1)
        return td.update(action)
1307+
1308+
class HeteroCountingEnv(EnvBase):
    """A heterogeneous, counting Env.

    Three nested "agents" share some observation entries (camera, vector)
    but differ in others (lidar presence, extra tensors) and in action
    dimensionality, so the stacked per-agent specs are heterogeneous
    ("lazy").  A single ``count`` buffer (one entry per batch element)
    starts at ``start_val`` and is incremented on each step in which any
    agent submits a nonzero action; every observation leaf is filled with
    the current count, rewards are always zero, and ``done`` is raised
    once ``count > max_steps``.
    """

    def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
        super().__init__(**kwargs)
        # Number of stacked heterogeneous sub-specs ("agents").
        self.n_nested_dim = 3
        self.max_steps = max_steps
        self.start_val = start_val

        # Per-batch-element step counter; registered as a buffer so it
        # moves with the env's device.
        count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int)
        count[:] = self.start_val

        self.register_buffer("count", count)

        # Build one (deliberately different) spec per agent, then stack
        # them along a new leading dim; stacking heterogeneous specs
        # yields a lazy stacked spec.
        obs_specs = []
        action_specs = []
        for index in range(self.n_nested_dim):
            obs_specs.append(self.get_agent_obs_spec(index))
            action_specs.append(self.get_agent_action_spec(index))
        obs_specs = torch.stack(obs_specs, dim=0)
        obs_spec_unlazy = consolidate_spec(obs_specs)
        action_specs = torch.stack(action_specs, dim=0)

        self.unbatched_observation_spec = CompositeSpec(
            lazy=obs_spec_unlazy,
            state=UnboundedContinuousTensorSpec(
                shape=(
                    64,
                    64,
                    3,
                )
            ),
        )

        self.unbatched_action_spec = CompositeSpec(
            lazy=action_specs,
        )
        # Reward and done live under the shared "lazy" entry: one scalar
        # per agent.
        self.unbatched_reward_spec = CompositeSpec(
            {
                "lazy": CompositeSpec(
                    {
                        "reward": UnboundedContinuousTensorSpec(
                            shape=(self.n_nested_dim, 1)
                        )
                    },
                    shape=(self.n_nested_dim,),
                )
            }
        )
        self.unbatched_done_spec = CompositeSpec(
            {
                "lazy": CompositeSpec(
                    {
                        "done": DiscreteTensorSpec(
                            n=2,
                            shape=(self.n_nested_dim, 1),
                            dtype=torch.bool,
                        ),
                    },
                    shape=(self.n_nested_dim,),
                )
            }
        )

        # Expand all unbatched specs to the env's batch size.
        self.action_spec = self.unbatched_action_spec.expand(
            *self.batch_size, *self.unbatched_action_spec.shape
        )
        self.observation_spec = self.unbatched_observation_spec.expand(
            *self.batch_size, *self.unbatched_observation_spec.shape
        )
        self.reward_spec = self.unbatched_reward_spec.expand(
            *self.batch_size, *self.unbatched_reward_spec.shape
        )
        self.done_spec = self.unbatched_done_spec.expand(
            *self.batch_size, *self.unbatched_done_spec.shape
        )

    def get_agent_obs_spec(self, i):
        """Return the observation spec of agent ``i`` (0, 1 or 2).

        The specs intentionally differ across agents (vector size 3 vs 2,
        missing lidar for agent 2, a distinct extra tensor per agent) to
        exercise heterogeneous-spec stacking.

        Raises:
            ValueError: if ``i`` is not in ``{0, 1, 2}``.
        """
        camera = BoundedTensorSpec(minimum=0, maximum=200, shape=(7, 7, 3))
        vector_3d = UnboundedContinuousTensorSpec(shape=(3,))
        vector_2d = UnboundedContinuousTensorSpec(shape=(2,))
        lidar = BoundedTensorSpec(minimum=0, maximum=5, shape=(8,))

        tensor_0 = UnboundedContinuousTensorSpec(shape=(1,))
        tensor_1 = BoundedTensorSpec(minimum=0, maximum=3, shape=(1, 2))
        tensor_2 = UnboundedContinuousTensorSpec(shape=(1, 2, 3))

        if i == 0:
            return CompositeSpec(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_3d,
                    "tensor_0": tensor_0,
                }
            )
        elif i == 1:
            return CompositeSpec(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_2d,
                    "tensor_1": tensor_1,
                }
            )
        elif i == 2:
            return CompositeSpec(
                {
                    "camera": camera,
                    "vector": vector_2d,
                    "tensor_2": tensor_2,
                }
            )
        else:
            raise ValueError(f"Index {i} undefined for index 3")

    def get_agent_action_spec(self, i):
        """Return the action spec of agent ``i`` (0, 1 or 2).

        Raises:
            ValueError: if ``i`` is not in ``{0, 1, 2}``.
        """
        action_3d = BoundedTensorSpec(minimum=-1, maximum=1, shape=(3,))
        action_2d = BoundedTensorSpec(minimum=-1, maximum=1, shape=(2,))

        # Some have 2d action and some 3d
        # TODO Introduce composite heterogeneous actions
        if i == 0:
            ret = action_3d
        elif i == 1:
            ret = action_2d
        elif i == 2:
            ret = action_2d
        else:
            raise ValueError(f"Index {i} undefined for index 3")

        return CompositeSpec({"action": ret})

    def _reset(
        self,
        tensordict: TensorDictBase = None,
        **kwargs,
    ) -> TensorDictBase:
        # Partial reset: only zero the counter where "_reset" is set;
        # with no mask, reset every batch element.
        if tensordict is not None and "_reset" in tensordict.keys():
            # Collapse the mask's trailing/per-agent dims to one flag per
            # batch element (any agent requesting reset resets the count).
            _reset = tensordict.get("_reset").squeeze(-1).any(-1)
            self.count[_reset] = self.start_val
        else:
            self.count[:] = self.start_val

        # Fill every observation leaf with the current count, broadcast
        # (expanded right) to the leaf's shape.
        reset_td = self.observation_spec.zero()
        reset_td.apply_(lambda x: x + expand_right(self.count, x.shape))
        reset_td.update(self.output_spec["_done_spec"].zero())

        assert reset_td.batch_size == self.batch_size

        return reset_td

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        # A batch element "acted" if any agent's first action component is
        # nonzero; `+=` on a bool tensor behaves as a logical OR.
        actions = torch.zeros_like(self.count.squeeze(-1), dtype=torch.bool)
        for i in range(self.n_nested_dim):
            action = tensordict["lazy"][..., i]["action"]
            action = action[..., 0].to(torch.bool)
            actions += action

        # Increment the counter by 1 wherever at least one agent acted.
        self.count += actions.unsqueeze(-1).to(torch.int)

        # Observations mirror the counter; reward is always zero.
        td = self.observation_spec.zero()
        td.apply_(lambda x: x + expand_right(self.count, x.shape))
        td.update(self.output_spec["_done_spec"].zero())
        td.update(self.output_spec["_reward_spec"].zero())

        assert td.batch_size == self.batch_size
        # Episode terminates once the counter exceeds max_steps.
        td[self.done_key] = expand_right(
            self.count > self.max_steps, self.done_spec.shape
        )

        return td.select().set("next", td)

    def _set_seed(self, seed: Optional[int]):
        # Env dynamics are deterministic; seeding only affects torch's
        # global RNG for any downstream sampling.
        torch.manual_seed(seed)

0 commit comments

Comments
 (0)