Skip to content

Commit f9a4d73

Browse files
Brax Teambtaba
Brax Team
authored andcommitted
Internal change
PiperOrigin-RevId: 605094677 Change-Id: I5983e766a55834b56ef4ab037ec4247398a25851
1 parent 532a88a commit f9a4d73

File tree

9 files changed

+40
-21
lines changed

9 files changed

+40
-21
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ If you would like to reference Brax in a publication, please use:
100100
author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem},
101101
title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation},
102102
url = {http://github.com/google/brax},
103-
version = {0.9.4},
103+
version = {0.10.0},
104104
year = {2021},
105105
}
106106
```

brax/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Import top-level classes and functions here for encapsulation/clarity."""
1616

17-
__version__ = '0.9.4'
17+
__version__ = '0.10.0'
1818

1919
from brax.base import Motion
2020
from brax.base import State

brax/contact.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ def get(sys: System, x: Transform) -> Optional[Contact]:
3535
Returns:
3636
Contact pytree
3737
"""
38-
# TODO: use mjx.ncon.
39-
ncon = mjx._src.collision_driver.ncon(sys)
38+
ncon = mjx.ncon(sys)
4039
if not ncon:
4140
return None
4241

brax/io/json.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def dumps(sys: System, states: List[State]) -> Text:
156156
for id_ in range(sys.ngeom):
157157
link_idx = sys.geom_bodyid[id_] - 1
158158

159-
rgba = sys.mj_model.geom_rgba[id_]
159+
rgba = sys.geom_rgba[id_]
160160
if (rgba == [0.5, 0.5, 0.5, 1.0]).all():
161161
# convert the default mjcf color to brax default color
162162
rgba = np.array([0.4, 0.33, 0.26, 1.0])
@@ -171,8 +171,7 @@ def dumps(sys: System, states: List[State]) -> Text:
171171
}
172172

173173
if geom['name'] in ('Mesh', 'Box'):
174-
# TODO: use sys.geom_dataid.
175-
vert, face = _get_mesh(sys.mj_model, sys.mj_model.geom_dataid[id_])
174+
vert, face = _get_mesh(sys.mj_model, sys.geom_dataid[id_])
176175
geom['vert'] = vert
177176
geom['face'] = face
178177

brax/io/mjcf.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -385,25 +385,32 @@ def load_model(mj: mujoco.MjModel) -> System:
385385
)
386386

387387
# create actuators
388+
# TODO: swap brax actuation for mjx actuation model.
388389
ctrl_range = mj.actuator_ctrlrange
389390
ctrl_range[~(mj.actuator_ctrllimited == 1), :] = np.array([-np.inf, np.inf])
390391
force_range = mj.actuator_forcerange
391392
force_range[~(mj.actuator_forcelimited == 1), :] = np.array([-np.inf, np.inf])
392-
q_id = np.array([mj.jnt_qposadr[i] for i in mj.actuator_trnid[:, 0]])
393-
qd_id = np.array([mj.jnt_dofadr[i] for i in mj.actuator_trnid[:, 0]])
394393
bias_q = mj.actuator_biasprm[:, 1] * (mj.actuator_biastype != 0)
395394
bias_qd = mj.actuator_biasprm[:, 2] * (mj.actuator_biastype != 0)
395+
# mask actuators since brax only supports joint transmission types
396+
act_mask = mj.actuator_trntype == mujoco.mjtTrn.mjTRN_JOINT
397+
trnid = mj.actuator_trnid[act_mask, 0].astype(np.uint32)
398+
q_id = mj.jnt_qposadr[trnid]
399+
qd_id = mj.jnt_dofadr[trnid]
400+
act_kwargs = {
401+
'gain': mj.actuator_gainprm[:, 0],
402+
'gear': mj.actuator_gear[:, 0],
403+
'ctrl_range': ctrl_range,
404+
'force_range': force_range,
405+
'bias_q': bias_q,
406+
'bias_qd': bias_qd,
407+
}
408+
act_kwargs = jax.tree_map(lambda x: x[act_mask], act_kwargs)
396409

397-
# TODO: remove brax actuators
398410
actuator = Actuator( # pytype: disable=wrong-arg-types
399411
q_id=q_id,
400412
qd_id=qd_id,
401-
gain=mj.actuator_gainprm[:, 0],
402-
gear=mj.actuator_gear[:, 0],
403-
ctrl_range=ctrl_range,
404-
force_range=force_range,
405-
bias_q=bias_q,
406-
bias_qd=bias_qd,
413+
**act_kwargs
407414
)
408415

409416
# create non-pytree params. these do not live on device directly, and they

brax/io/mjcf_test.py

+9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from absl.testing import absltest
1919
from brax import test_utils
2020
from brax.io import mjcf
21+
import mujoco
2122
import numpy as np
2223

2324
assert_almost_equal = np.testing.assert_array_almost_equal
@@ -131,6 +132,14 @@ def test_world_fromto(self):
131132
sys = test_utils.load_fixture('world_fromto.xml')
132133
mjcf.validate_model(sys.mj_model)
133134

135+
def test_loads_different_transmission(self):
136+
"""Tests that the brax model loads with different transmission types."""
137+
mj = test_utils.load_fixture_mujoco('ant.xml')
138+
mj.actuator_trntype[0] = mujoco.mjtTrn.mjTRN_SITE
139+
mjcf.load_model(mj) # loads without raising an error
140+
141+
with self.assertRaisesRegex(NotImplementedError, 'transmission types'):
142+
mjcf.validate_model(mj) # raises an error
134143

135144
if __name__ == '__main__':
136145
absltest.main()

docs/release-notes/next-release.md

-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1 @@
11
# Brax Release Notes
2-
3-
* Rebase brax System and State onto mjx.Model and mjx.Data.
4-
* Use the MuJoCo renderer instead of pytinyrenderer for brax.io.image.
5-
* Separate validation logic from the model loader in brax.io.mjcf.

docs/release-notes/v0.10.0.md

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Brax v0.10.0 Release Notes
2+
3+
This minor release makes several changes to the brax API, such that [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) data structures are the core data structures used in brax. This allows for more seamless model loading from `MuJoCo` XMLs, and allows for running `MJX` physics more seamlessly in brax.
4+
5+
* Rebase brax `System` and `State` onto `mjx.Model` and `mjx.Data`.
6+
* Separate validation logic from the model loading logic in `brax.io.mjcf`. This allows users to load an [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) model in brax, without hitting validation errors for other physics backends like `positional` and `spring`.
7+
* Remove `System.geoms`, since `brax.System` inherits from `mjx.Model` and all geom information is available in `mjx.Model`. We also update the brax viewer to work with this new schema.
8+
* Delete the brax contact library and use the contact library from `MJX`.
9+
* Use the MuJoCo renderer instead of pytinyrenderer for `brax.io.image`.

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
setup(
2626
name="brax",
27-
version="0.9.4",
27+
version="0.10.0",
2828
description="A differentiable physics engine written in JAX.",
2929
author="Brax Authors",
3030
author_email="[email protected]",

0 commit comments

Comments
 (0)