[Question] How to save the model.policy by torch.jit.script #2099

Open
4 tasks done
yucthonni opened this issue Mar 14, 2025 · 0 comments
Labels
check the checklist (You have checked the required items in the checklist but you didn't do what is written...), question (Further information is requested), RTFM (Answer is the documentation)

Comments

@yucthonni

❓ Question

I would like to save a PPO model as a `model.pt` file with `torch.jit.script` for general use. But when I try

`scripted_model = th.jit.script(model.policy)`

I get the following error:
```
ValueError Traceback (most recent call last)
Cell In[49], line 1
----> 1 th.jit.script(model.policy)

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_script.py:1432, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1429 _TOPLEVEL = False
1431 try:
-> 1432 return _script_impl(
1433 obj=obj,
1434 optimize=optimize,
1435 _frames_up=_frames_up + 1,
1436 _rcb=_rcb,
1437 example_inputs=example_inputs,
1438 )
1439 finally:
1440 _TOPLEVEL = prev

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_script.py:1146, in _script_impl(obj, optimize, _frames_up, _rcb, example_inputs)
1144 if isinstance(obj, torch.nn.Module):
1145 obj = call_prepare_scriptable_func(obj)
-> 1146 return torch.jit._recursive.create_script_module(
1147 obj, torch.jit._recursive.infer_methods_to_compile
1148 )
1149 else:
1150 obj = obj.prepare_scriptable() if hasattr(obj, "prepare_scriptable") else obj # type: ignore[operator]

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:556, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
554 assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
555 check_module_initialized(nn_module)
--> 556 concrete_type = get_module_concrete_type(nn_module, share_types)
557 if not is_tracing:
558 AttributeTypeIsSupportedChecker().check(nn_module)

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:505, in get_module_concrete_type(nn_module, share_types)
501 return nn_module._concrete_type
503 if share_types:
504 # Look into the store of cached JIT types
--> 505 concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
506 else:
507 # Get a concrete type directly, without trying to re-use an existing JIT
508 # type from the type store.
509 concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:437, in ConcreteTypeStore.get_or_create_concrete_type(self, nn_module)
435 def get_or_create_concrete_type(self, nn_module):
436 """Infer a ConcreteType from this nn.Module instance. Underlying JIT types are re-used if possible."""
--> 437 concrete_type_builder = infer_concrete_type_builder(nn_module)
439 nn_module_type = type(nn_module)
440 if nn_module_type not in self.type_store:

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:272, in infer_concrete_type_builder(nn_module, share_types)
269 if name in user_annotated_ignored_attributes:
270 continue
--> 272 attr_type, _ = infer_type(name, item)
273 if item is None:
274 # Modules can be None. We don't have direct support for optional
275 # Modules, so the register it as an NoneType attribute instead.
276 concrete_type_builder.add_attribute(name, attr_type.type(), False, False)

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:228, in infer_concrete_type_builder.<locals>.infer_type(name, item)
222 try:
223 if (
224 name in class_annotations
225 and class_annotations[name]
226 != torch.nn.Module.__annotations__["forward"]
227 ):
--> 228 ann_to_type = torch.jit.annotations.ann_to_type(
229 class_annotations[name], fake_range()
230 )
231 attr_type = torch._C.InferredType(ann_to_type)
232 elif isinstance(item, torch.jit.Attribute):

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/annotations.py:514, in ann_to_type(ann, loc, rcb)
512 if the_type is not None:
513 return the_type
--> 514 raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")

ValueError: Unknown type annotation: '<class 'stable_baselines3.common.torch_layers.BaseFeaturesExtractor'>' at
```
Is there a way to save the PPO model with `torch.jit.script` or `torch.jit.trace`? Thanks.

Checklist

yucthonni added the question label on Mar 14, 2025.
araffin added the RTFM and check the checklist labels on Mar 14, 2025.

2 participants