Skip to content

Commit

Permalink
improved visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
tatsukamijo committed Nov 26, 2024
1 parent 859bdf3 commit 095335e
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 62 deletions.
78 changes: 45 additions & 33 deletions lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,56 +249,64 @@ def load_from_raw(
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}

keys = [key for key in data_dict if "observation.images." in key]
keys = [key for key in data_dict if "image" in key]
for key in keys:
if video:
features[key] = VideoFrame()
else:
features[key] = Image()

features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
# features["observation.state"] = Sequence(
# length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
# )
if "observation.qpos" in data_dict:
features["observation.qpos"] = Sequence(
length=data_dict["observation.qpos"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
if "observation.qvel" in data_dict:
features["observation.qvel"] = Sequence(
length=data_dict["observation.qvel"].shape[1], feature=Value(dtype="float32", id=None)
)
# if "observation.effort" in data_dict:
# features["observation.effort"] = Sequence(
# length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
# )
if "observation.ft" in data_dict:
features["observation.ft"] = Sequence(
length=data_dict["observation.ft"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.eef_pos" in data_dict:
features["observation.eef_pos"] = Sequence(
length=data_dict["observation.eef_pos"].shape[1], feature=Value(dtype="float32", id=None)
if "observation.tactile.flow" in data_dict:
features["observation.tactile.flow"] = Sequence(
length=data_dict["observation.tactile.flow"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.eef_pos.position" in data_dict:
features["observation.eef_pos.position"] = Sequence(
length=data_dict["observation.eef_pos.position"].shape[1], feature=Value(dtype="float32", id=None)
# if "observation.eef_pos" in data_dict:
# features["observation.eef_pos"] = Sequence(
# length=data_dict["observation.eef_pos"].shape[1], feature=Value(dtype="float32", id=None)
# )
if "observation.eef.position" in data_dict:
features["observation.eef.position"] = Sequence(
length=data_dict["observation.eef.position"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.eef_pos.rotation_ortho6" in data_dict:
features["observation.eef_pos.rotation_ortho6"] = Sequence(
length=data_dict["observation.eef_pos.rotation_ortho6"].shape[1], feature=Value(dtype="float32", id=None)
if "observation.eef.rotation_axis_angle" in data_dict:
features["observation.eef.rotation_axis_angle"] = Sequence(
length=data_dict["observation.eef.rotation_axis_angle"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.eef_pos.gripper" in data_dict:
features["observation.eef_pos.gripper"] = Sequence(
length=data_dict["observation.eef_pos.gripper"].shape[1], feature=Value(dtype="float32", id=None)
if "observation.eef.rotation_ortho6" in data_dict:
features["observation.eef.rotation_ortho6"] = Sequence(
length=data_dict["observation.eef.rotation_ortho6"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.eef_vel" in data_dict:
features["observation.eef_vel"] = Sequence(
length=data_dict["observation.eef_vel"].shape[1], feature=Value(dtype="float32", id=None)
if "observation.gripper" in data_dict:
features["observation.gripper"] = Sequence(
length=data_dict["observation.gripper"].shape[1], feature=Value(dtype="float32", id=None)
)

features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
# if "observation.eef_vel" in data_dict:
# features["observation.eef_vel"] = Sequence(
# length=data_dict["observation.eef_vel"].shape[1], feature=Value(dtype="float32", id=None)
# )

# features["action"] = Sequence(
# length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
# )
if "action.position" in data_dict:
features["action.position"] = Sequence(
length=data_dict["action.position"].shape[1], feature=Value(dtype="float32", id=None)
Expand All @@ -307,9 +315,13 @@ def to_hf_dataset(data_dict, video) -> Dataset:
features["action.rotation_ortho6"] = Sequence(
length=data_dict["action.rotation_ortho6"].shape[1], feature=Value(dtype="float32", id=None)
)
if "action.stiffness_diag" in data_dict:
features["action.stiffness_diag"] = Sequence(
length=data_dict["action.stiffness_diag"].shape[1], feature=Value(dtype="float32", id=None)
if "action.stiffness_diag.trans" in data_dict:
features["action.stiffness_diag.trans"] = Sequence(
length=data_dict["action.stiffness_diag.trans"].shape[1], feature=Value(dtype="float32", id=None)
)
if "action.stiffness_diag.rot" in data_dict:
features["action.stiffness_diag.rot"] = Sequence(
length=data_dict["action.stiffness_diag.rot"].shape[1], feature=Value(dtype="float32", id=None)
)
if "action.gripper" in data_dict:
features["action.gripper"] = Sequence(
Expand Down
7 changes: 0 additions & 7 deletions lerobot/scripts/push_dataset_to_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,7 @@
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
<<<<<<< HEAD
from lerobot.common.datasets.utils import flatten_dict
=======
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
>>>>>>> d23faab526d85438722e58f536cb7048afef3d08


def get_from_raw_to_lerobot_format_fn(raw_format: str):
Expand Down Expand Up @@ -401,16 +397,13 @@ def main():
help="When set to 1, resumes a previous run.",
)
parser.add_argument(
<<<<<<< HEAD
=======
"--cache-dir",
type=Path,
required=False,
default="/tmp",
help="Directory to store the temporary videos and images generated while creating the dataset.",
)
parser.add_argument(
>>>>>>> d23faab526d85438722e58f536cb7048afef3d08
"--tests-data-dir",
type=Path,
help=(
Expand Down
114 changes: 92 additions & 22 deletions lerobot/scripts/visualize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@
from pathlib import Path
from typing import Iterator

from cv2 import magnitude
import numpy as np
import rerun as rr
import rerun.blueprint as rrb
import torch
import torch.utils.data
import tqdm
from typing import Union

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

Expand Down Expand Up @@ -108,9 +111,9 @@ def visualize_dataset(
web_port: int = 9090,
ws_port: int = 9087,
save: bool = False,
root: Path | None = None,
output_dir: Path | None = None,
) -> Path | None:
root: Union[Path, None] = None,
output_dir: Union[Path, None] = None,
) -> Union[Path, None]:
if save:
assert (
output_dir is not None
Expand All @@ -136,6 +139,7 @@ def visualize_dataset(
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)


# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
# when iterating on a dataloader with `num_workers` > 0
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
Expand All @@ -145,10 +149,13 @@ def visualize_dataset(
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)

logging.info("Logging to Rerun")

eef_positions = []
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# iterate over the batch
for i in range(len(batch["index"])):

# Note: Playback speed (fps) is still not settable from code.
# https://github.com/rerun-io/rerun/issues/5577
rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())

Expand All @@ -157,24 +164,87 @@ def visualize_dataset(
# TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))

# display each dimension of action space (e.g. actuators command)
if "action" in batch:
for dim_idx, val in enumerate(batch["action"][i]):
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))

# display each dimension of observed state space (e.g. agent position in joint space)
if "observation.state" in batch:
for dim_idx, val in enumerate(batch["observation.state"][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))

if "next.done" in batch:
rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))

if "next.reward" in batch:
rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))

if "next.success" in batch:
rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
# display EEF trajectory
if "observation.eef.position" in batch:
eef_pos = batch["observation.eef.position"][i].numpy()
eef_positions.append(eef_pos)
rr.log("observation/eef/position", rr.LineStrips3D(
np.array(eef_positions),
))
rr.log("observation/eef/frame", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True) # Set an up-axis
rr.log(
"observation/eef/frame_arrow",
rr.Arrows3D(
vectors=[[0.5, 0, 0], [0, 0.5, 0], [0, 0, 0.5]],
# vectors=[[0, 0.5, 0], [-0.5, 0, 0], [0, 0, 0.5]],
colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],
),
)

# display the summed magnitude of flow vectors
if "observation.tactile.flow" in batch:
vectors = batch["observation.tactile.flow"][i].numpy()
flow_vectors = vectors.reshape(-1, 2)
magnitudes = np.linalg.norm(flow_vectors, axis=1)
total_magnitude = magnitudes.sum()
rr.log("observation/flow_vector_magnitude", rr.Scalar(total_magnitude))

# display F/T
if "observation.ft" in batch:
force = batch["observation.ft"][i].numpy()[:3]
torque = batch["observation.ft"][i].numpy()[3:]

rr.log("observation/ft/force_z", rr.Scalar(force[2]))
rr.log("observation/ft/force_y", rr.Scalar(force[1]))
rr.log("observation/ft/force_x", rr.Scalar(force[0]))
rr.log("observation/ft/torque_x", rr.Scalar(torque[0]))
rr.log("observation/ft/torque_y", rr.Scalar(torque[1]))
rr.log("observation/ft/torque_z", rr.Scalar(torque[2]))

# display extra_camera, traj, tactile, tactile vector, force
blueprint = rrb.Blueprint(
rrb.Horizontal(
rrb.Vertical(
rrb.Spatial2DView(name="extra camera", origin="observation.image.extra_camera"),
rrb.Spatial3DView(
name="EEF trajectory",
contents=[
"+ /observation/eef/position",
"+ /observation/eef/frame",
"+ /observation/eef/frame_arrow",
],
background=[0, 0, 0],
overrides={
"observation/eef/position": [rr.components.Color([128, 0, 128])]
},
),
),
rrb.Vertical(
rrb.Spatial2DView(name="tactile flow", origin="observation.tactile.image_flow_overlaid"),
rrb.TimeSeriesView(
name="force plot",
contents=[
"+ /observation/ft/force_x",
"+ /observation/ft/force_y",
"+ /observation/ft/force_z",
])
),
rrb.Vertical(
rrb.TimeSeriesView(name="tactile magnitude", origin="observation/flow_vector_magnitude"),
rrb.TimeSeriesView(
name="torque plot",
contents=[
"+ /observation/ft/torque_x",
"+ /observation/ft/torque_y",
"+ /observation/ft/torque_z",
])
),
),
rrb.BlueprintPanel(state="expanded"),
rrb.SelectionPanel(state="collapsed"),
rrb.TimePanel(state="collapsed"),
)
rr.send_blueprint(blueprint)

if mode == "local" and save:
# save .rrd locally
Expand Down

0 comments on commit 095335e

Please sign in to comment.