HuggingFace Integration #94
Conversation
…ll as a test for tuple action spaces and a combo env with nested dict and tuple action spaces
…rvations for create_dataset_from_buffers, this may be inefficient and need refactoring
… observation and action space of data now saved in dataset.
…pss+1 observations were being loaded when calling get_episodes
…dict in datacollector after termination or truncation
…ndencies file name to common.py, removed dependency duplication in serialization.py, added a dataset integrity check to test_download_dataset_from_farama_server
…corresponding test
I added a few comments on the functions, but overall I have some concerns about the API.
For the user, the current workflow is something like this:

```python
import gymnasium as gym
import minari
from minari import DataCollectorV0
from minari.integrations.hugging_face import (
    convert_hugging_face_dataset_to_minari_dataset,
    convert_minari_dataset_to_hugging_face_dataset,
    pull_dataset_from_hugging_face,
    push_dataset_to_hugging_face,
)

env = DataCollectorV0(gym.make("EnvName"))
...  # code that generates the dataset
dataset = minari.create_dataset_from_collector_env(...)
hf_dataset = convert_minari_dataset_to_hugging_face_dataset(dataset)
push_dataset_to_hugging_face(hf_dataset, "name/repo")
```

and then

```python
hf_dataset = pull_dataset_from_hugging_face("name/repo")
dataset = convert_hugging_face_dataset_to_minari_dataset(hf_dataset)
```

The main red flag for me here is that we have public functions that are specifically for Hugging Face. I think this should be (more) transparent to the user, i.e. the functions in `minari/integrations/hugging_face.py` should all be private.
I have a couple of alternatives in mind:
Add an HF flag to load and upload:

```python
dataset = minari.load_dataset("name/repo", hugging_face_hub=True)
```

which pulls the dataset and returns a `MinariDataset` that reads from the HF dataset using the `HuggingFaceStorage` that I suggested in a review comment.
Similarly for pushing:

```python
minari.upload_dataset("dataset-name", hugging_face_hub=True)
```
Cons of this:
- We still have flags specifically for HF.
- It loads the dataset directly from the cloud, while we have a `download_dataset` function. I imagine that, in the future, we will also want the possibility to stream directly from the cloud: shall we drop `download_dataset` and implicitly download on `load_dataset`?
Use a `setup_remote()`:
We discussed having the possibility to set up remotes other than our GCP bucket. This is a particular case of that. We can create the API for that and use it in this case:

```python
minari.setup_remote(
    "https://huggingface.co/balisujohn",
    # other args, like auth_key
)
```

And now every `load_dataset`/`upload_dataset` pulls from/pushes to the HF hub directly as before, and conversions are done under the hood.
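A hypothetical sketch of how that dispatch could look; nothing below exists in Minari, and `_pull_from_hf`, `_convert_from_hf`, and `_load_from_gcp` are made-up placeholders:

```python
# Hypothetical sketch only: shows how setup_remote would change the
# behavior of later load_dataset calls, which is also the complexity
# concern raised below.
_remote_config = {"type": "gcp"}  # default: the Farama GCP bucket


def setup_remote(url: str, auth_key: str | None = None) -> None:
    # Record which backend subsequent calls should talk to.
    remote_type = "hf" if "huggingface.co" in url else "gcp"
    _remote_config.update(type=remote_type, url=url, auth_key=auth_key)


def load_dataset(dataset_id: str):
    if _remote_config["type"] == "hf":
        hf_dataset = _pull_from_hf(_remote_config["url"], dataset_id)
        return _convert_from_hf(hf_dataset)  # conversion hidden from the user
    return _load_from_gcp(dataset_id)
```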
Cons:
- Still, `download_dataset` makes no sense here.
- `setup_remote` should also work for the GCP bucket, and who knows what else.
- It complicates the library code, as `setup_remote` changes the behavior of other functions.
I lean more toward the second version, but it also requires more work.
As HF uses Arrow, I am wondering: if we switch to Arrow as we discussed, will we natively support HF Datasets without needing any conversion? This is also a reason to hide the conversion from the user.
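For intuition, a minimal sketch (assuming `pyarrow` and `datasets` are installed) of why Arrow-backed storage could make the conversion nearly free: an HF `Dataset` is essentially a wrapper around a `pyarrow.Table`.

```python
# Minimal sketch: datasets.Dataset wraps a pyarrow.Table, so if Minari
# stored episodes as Arrow tables, an HF Dataset could be created without
# copying or converting the underlying data.
import pyarrow as pa
from datasets import Dataset

table = pa.table(
    {
        "episode_ids": [0, 0, 1],
        "rewards": [0.0, 1.0, 0.5],
        "terminations": [False, True, True],
    }
)
hf_dataset = Dataset(table)  # wraps the existing Arrow table directly
print(hf_dataset[0])  # {'episode_ids': 0, 'rewards': 0.0, 'terminations': False}
```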
```python
        assert False, f"error, invalid observation or action structure{data}"


def convert_minari_dataset_to_hugging_face_dataset(dataset: MinariDataset):
```
missing return type
```python
from minari.serialization import deserialize_space, serialize_space


def _reconstuct_obs_or_action_at_index_recursive(
```
we already have this function; consider refactoring
"""Converts a MinariDataset into a HuggingFace datasets dataset.""" | ||
episodes = [episode for episode in dataset.iterate_episodes()] | ||
episodes_dict = { | ||
"observations": [], | ||
"actions": [], | ||
"rewards": [], | ||
"truncations": [], | ||
"terminations": [], | ||
"episode_ids": [], | ||
} | ||
for episode in episodes: | ||
episodes_dict["observations"].extend( | ||
[ | ||
_reconstuct_obs_or_action_at_index_recursive(episode.observations, i) | ||
for i in range(episode.total_timesteps + 1) | ||
] | ||
) | ||
episodes_dict["actions"].extend( | ||
[ | ||
_reconstuct_obs_or_action_at_index_recursive(episode.actions, i) | ||
for i in range(episode.total_timesteps) | ||
] | ||
+ [ | ||
None, | ||
] | ||
) | ||
episodes_dict["rewards"].extend( | ||
list(episode.rewards) | ||
+ [ |
This function can be extremely slow for big datasets.
Wouldn't it be better to use a generator and the `from_generator()` method?
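A hedged sketch of that suggestion (not the PR's code): stream steps one at a time into `datasets.Dataset.from_generator` instead of materializing one giant dict. The field names mirror the snippet above; `episode.id` and the `None` padding on the final step are assumptions.

```python
from datasets import Dataset


def _episode_steps(dataset):
    """Yield one row per timestep instead of building everything in memory."""
    for episode in dataset.iterate_episodes():
        for i in range(episode.total_timesteps + 1):
            last = i == episode.total_timesteps
            yield {
                "observations": _reconstuct_obs_or_action_at_index_recursive(
                    episode.observations, i
                ),
                # There is one more observation than action/reward per
                # episode; padding the last step with None mirrors the
                # snippet above.
                "actions": None
                if last
                else _reconstuct_obs_or_action_at_index_recursive(episode.actions, i),
                "rewards": None if last else float(episode.rewards[i]),
                "terminations": None if last else bool(episode.terminations[i]),
                "truncations": None if last else bool(episode.truncations[i]),
                "episode_ids": episode.id,
            }


hf_dataset = Dataset.from_generator(_episode_steps, gen_kwargs={"dataset": dataset})
```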
```python
    )


def convert_hugging_face_dataset_to_minari_dataset(dataset: Dataset):
```
missing return type
```python
    code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
    author="WillDudley",
    author_email="[email protected]",
```
should have meaningful values
Probably settable from function arguments?
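i.e., a hypothetical signature sketch for the test helper, with parameters mirroring the fields currently hard-coded above:

```python
# Hypothetical signature; none of these defaults are part of the PR.
def create_dummy_dataset_with_collecter_env_helper(
    dataset_id: str,
    env,
    algorithm_name: str = "random_policy",
    author: str | None = None,
    author_email: str | None = None,
    code_permalink: str | None = None,
):
    ...
```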
```python
def convert_hugging_face_dataset_to_minari_dataset(dataset: Dataset):

    description_data = json.loads(dataset.info.description)

    action_space = deserialize_space(description_data["action_space"])
    observation_space = deserialize_space(description_data["observation_space"])
    env_name = description_data["env_name"]
    dataset_id = description_data["dataset_id"]

    episode_ids = dataset.unique("episode_ids")

    buffer = []

    for episode_id in episode_ids:
```
This function can be very slow.
I propose instead reading from a HuggingFace Dataset using a custom `MinariStorage`.
To do so, we need an abstract class `MinariStorage` where we define the public methods that must be implemented. The current `MinariStorage` is actually an `HDF5Storage` that implements that interface. We can make a `HuggingFaceStorage` that reads from an HF Dataset.
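A rough sketch of that proposal; no such abstract class exists in Minari at this point, and all method names here are hypothetical:

```python
from abc import ABC, abstractmethod

from datasets import Dataset


class MinariStorage(ABC):
    """Hypothetical abstract interface; method names are illustrative."""

    @abstractmethod
    def total_episodes(self) -> int:
        """Return the number of episodes in the dataset."""

    @abstractmethod
    def get_episodes(self, episode_indices: list) -> list:
        """Return episode data for the given indices."""


class HuggingFaceStorage(MinariStorage):
    """Reads episodes directly from an HF Dataset, no conversion step."""

    def __init__(self, hf_dataset: Dataset):
        self._data = hf_dataset

    def total_episodes(self) -> int:
        return len(self._data.unique("episode_ids"))

    def get_episodes(self, episode_indices: list) -> list:
        # Naive row filter per episode; a real implementation would index.
        return [
            self._data.filter(lambda row: row["episode_ids"] == idx)
            for idx in episode_indices
        ]
```

The existing HDF5-backed `MinariStorage` would then be renamed to an `HDF5Storage` implementing the same interface.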
```python
    dataset_id=dataset_id,
    collector_env=env,
    algorithm_name="random_policy",
    code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
```
more informative value; it can be the GitHub link to the common.py file
```python
@pytest.mark.skip(
    reason="relies on a private repo, if you want to use this test locally, you'll need to change it to point at a repo you control"
)
```
we should have this test on a public repo
Note: I didn't do a comprehensive review of the code, just left a few comments for whatever stood out to me.
As for the higher-level design, I don't really mind having public functions in a separate `minari.integrations.hugging_face` namespace, if those are the ones that developers would use for pulling/uploading/converting.
I'd potentially consider making the conversion back and forth automatic inside of the push/pull functions, depending on whether or not it makes any sense to operate on "raw" HF datasets in the context of Minari.
@younik can you elaborate on your issue with those functions being public? Imo the namespace makes it explicit enough, but I might be missing something.
As for the two alternative proposals:
- I'm not necessarily a fan of integrating it into the core `load_dataset` etc. functions; we'd essentially tie core functionality to an external library and external servers. Keeping integrations separate (but accessible) is the right move imo.
- `setup_remote` sounds like a somewhat more ambitious plan for the future, like being generic between GCP/AWS/Azure/HF/whatever else, so my guess is that it's not a solution for right now?
```python
    elif isinstance(data, np.ndarray):
        return data[index]
    else:
        assert False, f"error, invalid observation or action structure{data}"
```
Is there a reason to `assert False` instead of just raising an exception?
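To illustrate the suggestion (the specific exception type here is just an illustrative choice):

```python
# assert statements are stripped when Python runs with -O, so an explicit
# exception is more robust for input validation.
    elif isinstance(data, np.ndarray):
        return data[index]
    else:
        raise TypeError(f"Invalid observation or action structure: {data}")
```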
@RedTachyon Also, if we want to support other libraries as well (e.g. RLDS), we need other conversion functions; I don't think this is interesting for the user. And we may want to add some loading keywords to …
```diff
@@ -28,6 +28,7 @@ dependencies = [
     "numpy >=1.21.0",
     "h5py>=3.8.0",
     "tqdm>=4.65.0",
+    "datasets>=2.13.0",
```
Should this be an optional requirement? Like `pip install minari[huggingface]`?
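A sketch of how that could look in `pyproject.toml`; the `huggingface` extra name is just a suggestion, not part of the PR:

```toml
# Hypothetical optional-dependency group, so "datasets" is only installed
# when the Hugging Face integration is wanted.
[project.optional-dependencies]
huggingface = ["datasets>=2.13.0"]
```

Then `pip install minari[huggingface]` would pull in the HF dependency only when needed.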
Description
Draft PR for Hugging Face Minari integration.
Adds functions to convert back and forth between `MinariDataset` and `datasets.Dataset` from Hugging Face datasets. Additionally, it adds functions that allow the user to push and pull datasets from the Hugging Face Hub. The core code is ready for review, but there are still a few more features I will add, for which I'm adding a checklist to this description. I also refactored the tests slightly, creating the new helper function `create_dummy_dataset_with_collecter_env_helper` to avoid code repetition.
Additional Features
Checklist:
- [ ] I have run the `pre-commit` checks with `pre-commit run --all-files` (see `CONTRIBUTING.md` instructions to set it up)
- [ ] I have run `pytest -v` and no errors are present.
- [ ] I have solved any possible warnings that `pytest -v` has generated that are related to my code to the best of my knowledge.