Skip to content

Commit 6316a57

Browse files
authored
[Feature] IsaacGymEnvs integration (pytorch#1443)
1 parent e39e701 commit 6316a57

File tree

4 files changed

+264
-6
lines changed

4 files changed

+264
-6
lines changed

test/test_libs.py

+81
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
import importlib
6+
7+
_has_isaac = importlib.util.find_spec("isaacgym") is not None
8+
9+
if _has_isaac:
10+
# isaac gym asks to be imported before torch...
11+
import isaacgym # noqa
12+
import isaacgymenvs # noqa
13+
from torchrl.envs.libs.isaacgym import IsaacGymEnv
14+
515
import argparse
616
import importlib
717

@@ -18,6 +28,7 @@
1828
_make_multithreaded_env,
1929
CARTPOLE_VERSIONED,
2030
get_available_devices,
31+
get_default_devices,
2132
HALFCHEETAH_VERSIONED,
2233
PENDULUM_VERSIONED,
2334
PONG_VERSIONED,
@@ -1534,6 +1545,76 @@ def test_data(self, dataset):
15341545
assert len(data) // 2048 in (i, i - 1)
15351546

15361547

1548+
@pytest.mark.skipif(not _has_isaac, reason="IsaacGym not found")
1549+
@pytest.mark.parametrize(
1550+
"task",
1551+
[
1552+
"AllegroHand",
1553+
# "AllegroKuka",
1554+
# "AllegroKukaTwoArms",
1555+
# "AllegroHandManualDR",
1556+
# "AllegroHandADR",
1557+
"Ant",
1558+
# "Anymal",
1559+
# "AnymalTerrain",
1560+
# "BallBalance",
1561+
# "Cartpole",
1562+
# "FactoryTaskGears",
1563+
# "FactoryTaskInsertion",
1564+
# "FactoryTaskNutBoltPick",
1565+
# "FactoryTaskNutBoltPlace",
1566+
# "FactoryTaskNutBoltScrew",
1567+
# "FrankaCabinet",
1568+
# "FrankaCubeStack",
1569+
"Humanoid",
1570+
# "HumanoidAMP",
1571+
# "Ingenuity",
1572+
# "Quadcopter",
1573+
# "ShadowHand",
1574+
"Trifinger",
1575+
],
1576+
)
1577+
@pytest.mark.parametrize("num_envs", [10, 20])
1578+
@pytest.mark.parametrize("device", get_default_devices())
1579+
class TestIsaacGym:
1580+
@classmethod
1581+
def _run_on_proc(cls, q, task, num_envs, device):
1582+
try:
1583+
env = IsaacGymEnv(task=task, num_envs=num_envs, device=device)
1584+
check_env_specs(env)
1585+
q.put(("succeeded!", None))
1586+
except Exception as err:
1587+
q.put(("failed!", err))
1588+
raise err
1589+
1590+
def test_env(self, task, num_envs, device):
1591+
from torch import multiprocessing as mp
1592+
1593+
q = mp.Queue(1)
1594+
proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device))
1595+
try:
1596+
proc.start()
1597+
msg, error = q.get()
1598+
if msg != "succeeded!":
1599+
raise error
1600+
finally:
1601+
q.close()
1602+
proc.join()
1603+
1604+
#
1605+
# def test_collector(self, task, num_envs, device):
1606+
# env = IsaacGymEnv(task=task, num_envs=num_envs, device=device)
1607+
# collector = SyncDataCollector(
1608+
# env,
1609+
# policy=SafeModule(nn.LazyLinear(out_features=env.observation_spec['obs'].shape[-1]), in_keys=["obs"], out_keys=["action"]),
1610+
# frames_per_batch=20,
1611+
# total_frames=-1
1612+
# )
1613+
# for c in collector:
1614+
# assert c.shape == torch.Size([num_envs, 20])
1615+
# break
1616+
1617+
15371618
if __name__ == "__main__":
15381619
args, unknown = argparse.ArgumentParser().parse_known_args()
15391620
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/gym_like.py

-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ def read_obs(
166166
167167
"""
168168
if isinstance(observations, dict):
169-
observations = {key: value for key, value in observations.items()}
170169
if "state" in observations and "observation" not in observations:
171170
# we rename "state" in "observation" as "observation" is the conventional name
172171
# for single observation in torchrl.

torchrl/envs/libs/gym.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,8 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
528528
self._seed_calls_reset = False
529529
self._env.seed(seed=seed)
530530

531-
def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
532-
self.action_spec = _gym_to_torchrl_spec_transform(
531+
def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
532+
action_spec = _gym_to_torchrl_spec_transform(
533533
env.action_space,
534534
device=self.device,
535535
categorical_action_encoding=self._categorical_action_encoding,
@@ -544,18 +544,26 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
544544
observation_spec = CompositeSpec(pixels=observation_spec)
545545
else:
546546
observation_spec = CompositeSpec(observation=observation_spec)
547-
self.observation_spec = observation_spec
548547
if hasattr(env, "reward_space") and env.reward_space is not None:
549-
self.reward_spec = _gym_to_torchrl_spec_transform(
548+
reward_spec = _gym_to_torchrl_spec_transform(
550549
env.reward_space,
551550
device=self.device,
552551
categorical_action_encoding=self._categorical_action_encoding,
553552
)
554553
else:
555-
self.reward_spec = UnboundedContinuousTensorSpec(
554+
reward_spec = UnboundedContinuousTensorSpec(
556555
shape=[1],
557556
device=self.device,
558557
)
558+
if batch_size is not None:
559+
action_spec = action_spec.expand(*batch_size, *action_spec.shape)
560+
reward_spec = reward_spec.expand(*batch_size, *reward_spec.shape)
561+
observation_spec = observation_spec.expand(
562+
*batch_size, *observation_spec.shape
563+
)
564+
self.action_spec = action_spec
565+
self.reward_spec = reward_spec
566+
self.observation_spec = observation_spec
559567

560568
def _init_env(self):
561569
self.reset()

torchrl/envs/libs/isaacgym.py

+170
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import importlib.util
6+
7+
import itertools
8+
import warnings
9+
from typing import Any, Dict, List, Union
10+
11+
import numpy as np
12+
import torch
13+
14+
from tensordict import TensorDictBase
15+
from torchrl.envs import make_composite_from_td
16+
from torchrl.envs.libs.gym import GymWrapper
17+
18+
_has_isaac = importlib.util.find_spec("isaacgym") is not None
19+
20+
21+
class IsaacGymWrapper(GymWrapper):
22+
"""Wrapper for IsaacGymEnvs environments.
23+
24+
The original library can be found `here <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_
25+
and is based on IsaacGym which can be downloaded `through NVIDIA's webpage <https://developer.nvidia.com/isaac-gym>_`.
26+
27+
.. note:: IsaacGym environments cannot be executed consecutively, ie. instantiating one
28+
environment after another (even if it has been cleared) will cause
29+
CUDA memory issues. We recommend creating one environment per process only.
30+
If you need more than one environment, the best way to achieve that is
31+
to spawn them across processes.
32+
33+
.. note:: IsaacGym works on CUDA devices by essence. Make sure your machine
34+
has GPUs available and the required setup for IsaacGym (eg, Ubuntu 20.04).
35+
36+
"""
37+
38+
def __init__(
39+
self, env: "isaacgymenvs.tasks.base.vec_task.Env", **kwargs
40+
): # noqa: F821
41+
warnings.warn(
42+
"IsaacGym environment support is an experimental feature that may change in the future."
43+
)
44+
num_envs = env.num_envs
45+
super().__init__(
46+
env, torch.device(env.device), batch_size=torch.Size([num_envs]), **kwargs
47+
)
48+
if not hasattr(self, "task"):
49+
# by convention in IsaacGymEnvs
50+
self.task = env.__name__
51+
52+
def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
53+
super()._make_specs(env, batch_size=self.batch_size)
54+
self.done_spec = self.done_spec.squeeze(-1)
55+
self.observation_spec["obs"] = self.observation_spec["observation"]
56+
del self.observation_spec["observation"]
57+
58+
data = self.rollout(3).get("next")[..., 0]
59+
del data[self.reward_key]
60+
del data[self.done_key]
61+
specs = make_composite_from_td(data)
62+
63+
obs_spec = self.observation_spec
64+
obs_spec.unlock_()
65+
obs_spec.update(specs)
66+
obs_spec.lock_()
67+
self.__dict__["_observation_spec"] = obs_spec
68+
69+
@classmethod
70+
def _make_envs(cls, *, task, num_envs, device, seed=None, headless=True, **kwargs):
71+
import isaacgym # noqa
72+
import isaacgymenvs # noqa
73+
74+
envs = isaacgymenvs.make(
75+
seed=seed,
76+
task=task,
77+
num_envs=num_envs,
78+
sim_device=str(device),
79+
rl_device=str(device),
80+
headless=headless,
81+
**kwargs,
82+
)
83+
return envs
84+
85+
def _set_seed(self, seed: int) -> int:
86+
# as of #665c32170d84b4be66722eea405a1e08b6e7f761 the seed points nowhere in gym.make for IsaacGymEnvs
87+
return seed
88+
89+
def read_action(self, action):
90+
"""Reads the action obtained from the input TensorDict and transforms it in the format expected by the contained environment.
91+
92+
Args:
93+
action (Tensor or TensorDict): an action to be taken in the environment
94+
95+
Returns: an action in a format compatible with the contained environment.
96+
97+
"""
98+
return action
99+
100+
def read_done(self, done):
101+
"""Done state reader.
102+
103+
Reads a done state and returns a tuple containing:
104+
- a done state to be set in the environment
105+
- a boolean value indicating whether the frame_skip loop should be broken
106+
107+
Args:
108+
done (np.ndarray, boolean or other format): done state obtained from the environment
109+
110+
"""
111+
return done.bool(), done.any()
112+
113+
def read_reward(self, total_reward, step_reward):
114+
"""Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two.
115+
116+
Args:
117+
total_reward (torch.Tensor or TensorDict): total reward so far in the step
118+
step_reward (reward in the format provided by the inner env): reward of this particular step
119+
120+
"""
121+
return total_reward + step_reward
122+
123+
def read_obs(
124+
self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]
125+
) -> Dict[str, Any]:
126+
"""Reads an observation from the environment and returns an observation compatible with the output TensorDict.
127+
128+
Args:
129+
observations (observation under a format dictated by the inner env): observation to be read.
130+
131+
"""
132+
if isinstance(observations, dict):
133+
if "state" in observations and "observation" not in observations:
134+
# we rename "state" in "observation" as "observation" is the conventional name
135+
# for single observation in torchrl.
136+
# naming it 'state' will result in envs that have a different name for the state vector
137+
# when queried with and without pixels
138+
observations["observation"] = observations.pop("state")
139+
if not isinstance(observations, (TensorDictBase, dict)):
140+
(key,) = itertools.islice(self.observation_spec.keys(True, True), 1)
141+
observations = {key: observations}
142+
return observations
143+
144+
145+
class IsaacGymEnv(IsaacGymWrapper):
146+
"""A TorchRL Env interface for IsaacGym environments.
147+
148+
See :class:`~.IsaacGymWrapper` for more information.
149+
150+
Examples:
151+
>>> env = IsaacGymEnv(task="Ant", num_envs=2000, device="cuda:0")
152+
>>> rollout = env.rollout(3)
153+
>>> assert env.batch_size == (2000,)
154+
155+
"""
156+
157+
@property
158+
def available_envs(cls) -> List[str]:
159+
import isaacgymenvs # noqa
160+
161+
return list(isaacgymenvs.tasks.isaacgym_task_map.keys())
162+
163+
def __init__(self, task=None, *, env=None, num_envs, device, **kwargs):
164+
if env is not None and task is not None:
165+
raise RuntimeError("Cannot provide both `task` and `env` arguments.")
166+
elif env is not None:
167+
task = env
168+
envs = self._make_envs(task=task, num_envs=num_envs, device=device, **kwargs)
169+
self.task = task
170+
super().__init__(envs, **kwargs)

0 commit comments

Comments
 (0)