From 1b538202ccd5b4da80e8422b1fc742ee689f43ad Mon Sep 17 00:00:00 2001
From: Tatsuya Kamijo
Date: Wed, 5 Feb 2025 16:15:16 +0900
Subject: [PATCH] Add visualization for contactile data

---
 .../scripts/visualize_dataset_contactile.py   | 380 ++++++++++++++++++
 1 file changed, 380 insertions(+)
 create mode 100644 lerobot/scripts/visualize_dataset_contactile.py

diff --git a/lerobot/scripts/visualize_dataset_contactile.py b/lerobot/scripts/visualize_dataset_contactile.py
new file mode 100644
index 000000000..d882947f8
--- /dev/null
+++ b/lerobot/scripts/visualize_dataset_contactile.py
@@ -0,0 +1,380 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
+
+Note: The last frame of the episode doesn't always correspond to a final state.
+That's because our datasets are composed of transitions from state to state, up to
+the antepenultimate state associated with the ultimate action needed to reach the final state.
+However, there might not be a transition from a final state to another state.
+
+Note: This script aims to visualize the data used to train the neural networks.
+~What you see is what you get~. When visualizing the image modality, lossy compression
+artifacts are expected, since these images have been decoded from compressed mp4 videos to
+save disk space. The compression factor applied has been tuned to not affect the success rate.
+
+Examples:
+
+$ python lerobot/scripts/visualize_dataset_contactile.py --repo-id contactile --episode-index 1 --root /root/osx-ur/dependencies/datasets
+
+- Visualize data stored on a local machine:
+```
+local$ python lerobot/scripts/visualize_dataset_contactile.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0
+```
+
+- Visualize data stored on a distant machine with a local viewer:
+```
+distant$ python lerobot/scripts/visualize_dataset_contactile.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0 \
+    --save 1 \
+    --output-dir path/to/directory
+
+local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
+local$ rerun lerobot_pusht_episode_0.rrd
+```
+
+- Visualize data stored on a distant machine through streaming:
+(You need to forward the websocket port from the distant machine, e.g. with
+`ssh -L 9087:localhost:9087 username@remote-host`)
+```
+distant$ python lerobot/scripts/visualize_dataset_contactile.py \
+    --repo-id lerobot/pusht \
+    --episode-index 0 \
+    --mode distant \
+    --ws-port 9087
+
+local$ rerun ws://localhost:9087
+```
+
+"""
+
+import argparse
+import gc
+import logging
+import time
+from pathlib import Path
+from typing import Iterator, Union
+
+import numpy as np
+import rerun as rr
+import rerun.blueprint as rrb
+import torch
+import torch.utils.data
+import tqdm
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+
+class EpisodeSampler(torch.utils.data.Sampler):
+    def __init__(self, dataset: LeRobotDataset, episode_index: int):
+        from_idx = dataset.episode_data_index["from"][episode_index].item()
+        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        self.frame_ids = range(from_idx, to_idx)
+
+    def __iter__(self) -> Iterator:
+        return iter(self.frame_ids)
+
+    def __len__(self) -> int:
+        return len(self.frame_ids)
+
+
+def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
+    assert chw_float32_torch.dtype == torch.float32
+    assert chw_float32_torch.ndim == 3
+    c, h, w = chw_float32_torch.shape
+    assert c < h and c < w, f"expect channel-first images, but got {chw_float32_torch.shape}"
+    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
+    return hwc_uint8_numpy
+
+
+def visualize_dataset(
+    repo_id: str,
+    episode_index: int,
+    batch_size: int = 32,
+    num_workers: int = 0,
+    mode: str = "local",
+    web_port: int = 9090,
+    ws_port: int = 9087,
+    save: bool = False,
+    root: Union[Path, None] = None,
+    output_dir: Union[Path, None] = None,
+) -> Union[Path, None]:
+    if save:
+        assert (
+            output_dir is not None
+        ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
+
+    logging.info("Loading dataset")
+    dataset = LeRobotDataset(repo_id, root=root)
+
+    logging.info("Loading dataloader")
+    episode_sampler = EpisodeSampler(dataset, episode_index)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        batch_size=batch_size,
+        sampler=episode_sampler,
+    )
+
+    logging.info("Starting Rerun")
+
+    if mode not in ["local", "distant"]:
+        raise ValueError(mode)
+
+    spawn_local_viewer = mode == "local" and not save
+    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
+
+    # Manually call the Python garbage collector after `rr.init` to avoid hanging in a blocking flush
+    # when iterating on a dataloader with `num_workers` > 0
+    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
+    gc.collect()
+
+    if mode == "distant":
+        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
+
+    # rr.save(f"rerun/{repo_id}.rrd")
+
+    logging.info("Logging to Rerun")
+    eef_positions = []
+    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
+        # iterate over the batch
+        for i in range(len(batch["index"])):
+            # Note: Playback speed (fps) is still not settable from code.
+            # https://github.com/rerun-io/rerun/issues/5577
+            rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
+            rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
+
+            # display each camera image
+            for key in dataset.camera_keys:
+                # TODO(rcadene): add `.compress()`? is it lossless?
+                rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
+
+            # display EEF trajectory
+            if "observation.eef.position" in batch:
+                eef_pos = batch["observation.eef.position"][i].numpy()
+                eef_positions.append(eef_pos)
+                rr.log("observation/eef/position", rr.LineStrips3D(
+                    np.array(eef_positions),
+                ))
+                rr.log("observation/eef/frame", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)  # set an up-axis
+                rr.log(
+                    "observation/eef/frame_arrow",
+                    rr.Arrows3D(
+                        vectors=[[0.5, 0, 0], [0, 0.5, 0], [0, 0, 0.5]],
+                        # vectors=[[0, 0.5, 0], [-0.5, 0, 0], [0, 0, 0.5]],
+                        colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],
+                    ),
+                )
+
+            # display contactile global force and torque
+            if "observation.contactile" in batch:
+                flattened_array = batch["observation.contactile"][i].numpy()  # flattened, shape: (33,)
+                # reshape to (11, 3); row 9 holds the global force (x, y, z), row 10 the global torque (x, y, z)
+                reshaped_array = flattened_array.reshape(11, 3)
+                global_x_force = reshaped_array[9][0]
+                global_y_force = reshaped_array[9][1]
+                global_z_force = reshaped_array[9][2]
+                global_x_torque = reshaped_array[10][0]
+                global_y_torque = reshaped_array[10][1]
+                global_z_torque = reshaped_array[10][2]
+                rr.log("observation/contactile/g_force_z", rr.Scalar(global_z_force))
+                rr.log("observation/contactile/g_force_y", rr.Scalar(global_y_force))
+                rr.log("observation/contactile/g_force_x", rr.Scalar(global_x_force))
+                rr.log("observation/contactile/g_torque_z", rr.Scalar(global_z_torque))
+                rr.log("observation/contactile/g_torque_y", rr.Scalar(global_y_torque))
+                rr.log("observation/contactile/g_torque_x", rr.Scalar(global_x_torque))
+
+            # TODO: visualize "observation.vive_tracker_pose" (not displayed yet)
+
+            # display the summed magnitude of flow vectors
+            if "observation.tactile.flow" in batch:
+                vectors = batch["observation.tactile.flow"][i].numpy()
+                flow_vectors = vectors.reshape(-1, 2)
+                magnitudes = np.linalg.norm(flow_vectors, axis=1)
+                total_magnitude = magnitudes.sum()
+                rr.log("observation/flow_vector_magnitude", rr.Scalar(total_magnitude))
+
+            # display F/T
+            if "observation.ft" in batch:
+                force = batch["observation.ft"][i].numpy()[:3]
+                torque = batch["observation.ft"][i].numpy()[3:]
+
+                rr.log("observation/ft/force_z", rr.Scalar(force[2]))
+                rr.log("observation/ft/force_y", rr.Scalar(force[1]))
+                rr.log("observation/ft/force_x", rr.Scalar(force[0]))
+                rr.log("observation/ft/torque_x", rr.Scalar(torque[0]))
+                rr.log("observation/ft/torque_y", rr.Scalar(torque[1]))
+                rr.log("observation/ft/torque_z", rr.Scalar(torque[2]))
+
+    # blueprint layout: extra camera, EEF trajectory, contactile force/torque plots, F/T force/torque plots
+    blueprint = rrb.Blueprint(
+        rrb.Horizontal(
+            rrb.Vertical(
+                rrb.Spatial2DView(name="extra camera", origin="observation.image.extra_camera"),
+                rrb.Spatial3DView(
+                    name="EEF trajectory",
+                    contents=[
+                        "+ /observation/eef/position",
+                        "+ /observation/eef/frame",
+                        "+ /observation/eef/frame_arrow",
+                    ],
+                    background=[0, 0, 0],
+                    overrides={
+                        "observation/eef/position": [rr.components.Color([128, 0, 128])]
+                    },
+                ),
+            ),
+            rrb.Vertical(
+                rrb.TimeSeriesView(
+                    name="contactile force plot",
+                    contents=[
+                        "+ /observation/contactile/g_force_z",
+                        "+ /observation/contactile/g_force_y",
+                        "+ /observation/contactile/g_force_x",
+                    ]),
+                rrb.TimeSeriesView(
+                    name="force plot",
+                    contents=[
+                        "+ /observation/ft/force_x",
+                        "+ /observation/ft/force_y",
+                        "+ /observation/ft/force_z",
+                    ])
+            ),
+            rrb.Vertical(
+                rrb.TimeSeriesView(
+                    name="contactile torque plot",
+                    contents=[
+                        "+ /observation/contactile/g_torque_z",
+                        "+ /observation/contactile/g_torque_y",
+                        "+ /observation/contactile/g_torque_x",
+                    ]),
+                rrb.TimeSeriesView(
+                    name="torque plot",
+                    contents=[
+                        "+ /observation/ft/torque_x",
+                        "+ /observation/ft/torque_y",
+                        "+ /observation/ft/torque_z",
+                    ])
+            ),
+        ),
+        rrb.BlueprintPanel(state="expanded"),
+        rrb.SelectionPanel(state="collapsed"),
+        rrb.TimePanel(state="collapsed"),
+    )
+    rr.send_blueprint(blueprint)
+
+    if mode == "local" and save:
+        # save .rrd locally
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        repo_id_str = repo_id.replace("/", "_")
+        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
+        rr.save(rrd_path)
+        return rrd_path
+
+    elif mode == "distant":
+        # stop the process from exiting since it is serving the websocket connection
+        try:
+            while True:
+                time.sleep(1)
+        except KeyboardInterrupt:
+            print("Ctrl-C received. Exiting.")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Name of the Hugging Face repository containing a LeRobotDataset (e.g. `lerobot/pusht`).",
+    )
+    parser.add_argument(
+        "--episode-index",
+        type=int,
+        required=True,
+        help="Episode to visualize.",
+    )
+    parser.add_argument(
+        "--root",
+        type=Path,
+        default=None,
+        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset is loaded from the Hugging Face cache folder, or downloaded from the hub if available.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        default=None,
+        help="Directory path to write a .rrd file when `--save 1` is set.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=32,
+        help="Batch size loaded by DataLoader.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of DataLoader worker processes for loading the data.",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="local",
+        help=(
+            "Viewing mode: 'local' or 'distant'. "
+            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
+            "'distant' creates a server on the distant machine where the data is stored. "
+            "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
+        ),
+    )
+    parser.add_argument(
+        "--web-port",
+        type=int,
+        default=9090,
+        help="Web port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--ws-port",
+        type=int,
+        default=9087,
+        help="Web socket port for rerun.io when `--mode distant` is set.",
+    )
+    parser.add_argument(
+        "--save",
+        type=int,
+        default=0,
+        help=(
+            "Save a .rrd file in the directory provided by `--output-dir`. "
+            "It also deactivates the spawning of a viewer. "
+            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
+        ),
+    )
+
+    args = parser.parse_args()
+    visualize_dataset(**vars(args))
+
+
+if __name__ == "__main__":
+    main()
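
For reference, here is a minimal sketch (outside the patch, plain NumPy) of the `observation.contactile` decoding this script performs. The script only reads rows 9 and 10 of the reshaped array; treating rows 0-8 as per-pillar readings is an assumption, and `flattened` is a placeholder for one frame from the dataset.
```
import numpy as np

# One frame of `observation.contactile` is a flattened vector of 33 values, reshaped to (11, 3).
# Row 9 is the global force (x, y, z) and row 10 is the global torque (x, y, z), as logged by the script.
flattened = np.zeros(33, dtype=np.float32)  # placeholder for batch["observation.contactile"][i].numpy()
readings = flattened.reshape(11, 3)

pillar_readings = readings[:9]  # assumed per-pillar values; not used by the script
global_force = readings[9]      # logged as observation/contactile/g_force_{x,y,z}
global_torque = readings[10]    # logged as observation/contactile/g_torque_{x,y,z}
print("global force:", global_force, "global torque:", global_torque)
```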