|
7 | 7 | import torch
|
8 | 8 | import torch.nn as nn
|
9 | 9 | from tensordict.tensordict import TensorDict, TensorDictBase
|
10 |
| -from tensordict.utils import NestedKey |
| 10 | +from tensordict.utils import expand_right, NestedKey |
11 | 11 |
|
12 | 12 | from torchrl.data.tensor_specs import (
|
13 | 13 | BinaryDiscreteTensorSpec,
|
|
19 | 19 | TensorSpec,
|
20 | 20 | UnboundedContinuousTensorSpec,
|
21 | 21 | )
|
| 22 | +from torchrl.data.utils import consolidate_spec |
22 | 23 | from torchrl.envs.common import EnvBase
|
23 | 24 | from torchrl.envs.model_based.common import ModelBasedEnvBase
|
24 | 25 |
|
@@ -1290,3 +1291,195 @@ def _step(
|
1290 | 1291 | device=self.device,
|
1291 | 1292 | )
|
1292 | 1293 | return tensordict.select().set("next", tensordict)
|
| 1294 | + |
| 1295 | + |
class HeteroCountingEnvPolicy:
    """Callable policy that writes a constant action into a tensordict.

    The action is obtained from ``full_action_spec.zero()``; when ``count``
    is true every entry of that result is incremented by one before it is
    merged into the input tensordict.
    """

    def __init__(self, full_action_spec: TensorSpec, count: bool = True):
        """Store the action spec and whether actions should be incremented."""
        self.count = count
        self.full_action_spec = full_action_spec

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        """Update ``td`` with the generated action and return the result."""
        generated = self.full_action_spec.zero()
        if not self.count:
            # Non-counting mode: hand back the spec's zero action untouched.
            return td.update(generated)
        # Counting mode: shift every entry by one, in place.
        generated.apply_(lambda leaf: leaf + 1)
        return td.update(generated)
| 1306 | + |
| 1307 | + |
class HeteroCountingEnv(EnvBase):
    """A heterogeneous, counting Env.

    ``n_nested_dim`` (3) agents live under the ``"lazy"`` key, each with its
    own observation entries and action size. A single integer ``count``
    buffer is shared by all agents: a step increments it by one wherever at
    least one agent's first action component is non-zero. Observations are
    the spec's zero value plus the current count, and done is set once the
    count exceeds ``max_steps``.
    """

    def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
        """Create the counter buffer and the batched/unbatched specs.

        Args:
            max_steps: count above which the env reports done.
            start_val: value the counter is (re)set to.
            **kwargs: forwarded to ``EnvBase.__init__``.
        """
        super().__init__(**kwargs)
        self.n_nested_dim = 3
        self.max_steps = max_steps
        self.start_val = start_val

        # Shared counter: env batch dims followed by one trailing singleton.
        self.register_buffer(
            "count",
            torch.full(
                (*self.batch_size, 1),
                self.start_val,
                dtype=torch.int,
                device=self.device,
            ),
        )

        agent_ids = range(self.n_nested_dim)
        stacked_obs = torch.stack(
            [self.get_agent_obs_spec(agent) for agent in agent_ids], dim=0
        )
        stacked_actions = torch.stack(
            [self.get_agent_action_spec(agent) for agent in agent_ids], dim=0
        )

        self.unbatched_observation_spec = CompositeSpec(
            lazy=consolidate_spec(stacked_obs),
            state=UnboundedContinuousTensorSpec(shape=(64, 64, 3)),
        )
        self.unbatched_action_spec = CompositeSpec(lazy=stacked_actions)

        per_agent_reward = CompositeSpec(
            {
                "reward": UnboundedContinuousTensorSpec(
                    shape=(self.n_nested_dim, 1)
                )
            },
            shape=(self.n_nested_dim,),
        )
        self.unbatched_reward_spec = CompositeSpec({"lazy": per_agent_reward})

        per_agent_done = CompositeSpec(
            {
                "done": DiscreteTensorSpec(
                    n=2,
                    shape=(self.n_nested_dim, 1),
                    dtype=torch.bool,
                ),
            },
            shape=(self.n_nested_dim,),
        )
        self.unbatched_done_spec = CompositeSpec({"lazy": per_agent_done})

        def _with_batch(spec):
            # Prepend the env batch size to the spec's own shape.
            return spec.expand(*self.batch_size, *spec.shape)

        self.action_spec = _with_batch(self.unbatched_action_spec)
        self.observation_spec = _with_batch(self.unbatched_observation_spec)
        self.reward_spec = _with_batch(self.unbatched_reward_spec)
        self.done_spec = _with_batch(self.unbatched_done_spec)

    def get_agent_obs_spec(self, i):
        """Return the observation ``CompositeSpec`` of agent ``i`` (0, 1 or 2).

        Raises:
            ValueError: if ``i`` is not one of the three supported indices.
        """
        if i == 0:
            return CompositeSpec(
                {
                    "camera": BoundedTensorSpec(
                        minimum=0, maximum=200, shape=(7, 7, 3)
                    ),
                    "lidar": BoundedTensorSpec(minimum=0, maximum=5, shape=(8,)),
                    "vector": UnboundedContinuousTensorSpec(shape=(3,)),
                    "tensor_0": UnboundedContinuousTensorSpec(shape=(1,)),
                }
            )
        if i == 1:
            return CompositeSpec(
                {
                    "camera": BoundedTensorSpec(
                        minimum=0, maximum=200, shape=(7, 7, 3)
                    ),
                    "lidar": BoundedTensorSpec(minimum=0, maximum=5, shape=(8,)),
                    "vector": UnboundedContinuousTensorSpec(shape=(2,)),
                    "tensor_1": BoundedTensorSpec(
                        minimum=0, maximum=3, shape=(1, 2)
                    ),
                }
            )
        if i == 2:
            # The third agent has no lidar entry.
            return CompositeSpec(
                {
                    "camera": BoundedTensorSpec(
                        minimum=0, maximum=200, shape=(7, 7, 3)
                    ),
                    "vector": UnboundedContinuousTensorSpec(shape=(2,)),
                    "tensor_2": UnboundedContinuousTensorSpec(shape=(1, 2, 3)),
                }
            )
        raise ValueError(f"Index {i} undefined for index 3")

    def get_agent_action_spec(self, i):
        """Return the action ``CompositeSpec`` of agent ``i`` (0, 1 or 2).

        Raises:
            ValueError: if ``i`` is not one of the three supported indices.
        """
        # Some have 2d action and some 3d
        # TODO Introduce composite heterogeneous actions
        if i == 0:
            action_size = 3
        elif i in (1, 2):
            action_size = 2
        else:
            raise ValueError(f"Index {i} undefined for index 3")

        bounded_action = BoundedTensorSpec(
            minimum=-1, maximum=1, shape=(action_size,)
        )
        return CompositeSpec({"action": bounded_action})

    def _count_filled_observation(self):
        """Return the zero observation with the current count added to each entry."""
        obs_td = self.observation_spec.zero()
        obs_td.apply_(lambda leaf: leaf + expand_right(self.count, leaf.shape))
        return obs_td

    def _reset(
        self,
        tensordict: TensorDictBase = None,
        **kwargs,
    ) -> TensorDictBase:
        """Reset the counter (fully or where ``"_reset"`` is set) and return obs + done."""
        if tensordict is None or "_reset" not in tensordict.keys():
            # No mask supplied: every counter goes back to the start value.
            self.count[:] = self.start_val
        else:
            # Collapse the trailing dim, then reduce over the last remaining one.
            reset_mask = tensordict.get("_reset").squeeze(-1).any(-1)
            self.count[reset_mask] = self.start_val

        out = self._count_filled_observation()
        out.update(self.output_spec["_done_spec"].zero())

        assert out.batch_size == self.batch_size

        return out

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Advance the counter from the agents' actions and return the ``"next"`` data."""
        # Accumulates True wherever an agent's first action component is non-zero.
        increment = torch.zeros_like(self.count.squeeze(-1), dtype=torch.bool)
        for agent in range(self.n_nested_dim):
            agent_action = tensordict["lazy"][..., agent]["action"]
            increment += agent_action[..., 0].to(torch.bool)

        self.count += increment.unsqueeze(-1).to(torch.int)

        next_td = self._count_filled_observation()
        next_td.update(self.output_spec["_done_spec"].zero())
        next_td.update(self.output_spec["_reward_spec"].zero())

        assert next_td.batch_size == self.batch_size
        finished = self.count > self.max_steps
        next_td[self.done_key] = expand_right(finished, self.done_spec.shape)

        return next_td.select().set("next", next_td)

    def _set_seed(self, seed: Optional[int]):
        """Seed torch's global random number generator with ``seed``."""
        torch.manual_seed(seed)
0 commit comments