[Question] How to save the model.policy by torch.jit.script #2099
Labels: check the checklist (the required checklist items were ticked, but what they ask for was not done), question (further information is requested), RTFM (the answer is in the documentation)
❓ Question
I would like to save a PPO model as a `model.pt` file using `torch.jit.script` for general use. But when I try this:
`scripted_model = th.jit.script(model.policy)`
I get the following error:
```
ValueError Traceback (most recent call last)
Cell In[49], line 1
----> 1 th.jit.script(model.policy)
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_script.py:1432, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1429 _TOPLEVEL = False
1431 try:
-> 1432 return _script_impl(
1433 obj=obj,
1434 optimize=optimize,
1435 _frames_up=_frames_up + 1,
1436 _rcb=_rcb,
1437 example_inputs=example_inputs,
1438 )
1439 finally:
1440 _TOPLEVEL = prev
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_script.py:1146, in _script_impl(obj, optimize, _frames_up, _rcb, example_inputs)
1144 if isinstance(obj, torch.nn.Module):
1145 obj = call_prepare_scriptable_func(obj)
-> 1146 return torch.jit._recursive.create_script_module(
1147 obj, torch.jit._recursive.infer_methods_to_compile
1148 )
1149 else:
1150 obj = obj.prepare_scriptable() if hasattr(obj, "prepare_scriptable") else obj # type: ignore[operator]
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:556, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
554 assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
555 check_module_initialized(nn_module)
--> 556 concrete_type = get_module_concrete_type(nn_module, share_types)
557 if not is_tracing:
558 AttributeTypeIsSupportedChecker().check(nn_module)
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:505, in get_module_concrete_type(nn_module, share_types)
501 return nn_module._concrete_type
503 if share_types:
504 # Look into the store of cached JIT types
--> 505 concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
506 else:
507 # Get a concrete type directly, without trying to re-use an existing JIT
508 # type from the type store.
509 concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:437, in ConcreteTypeStore.get_or_create_concrete_type(self, nn_module)
435 def get_or_create_concrete_type(self, nn_module):
436 """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are re-used if possible."""
--> 437 concrete_type_builder = infer_concrete_type_builder(nn_module)
439 nn_module_type = type(nn_module)
440 if nn_module_type not in self.type_store:
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:272, in infer_concrete_type_builder(nn_module, share_types)
269 if name in user_annotated_ignored_attributes:
270 continue
--> 272 attr_type, _ = infer_type(name, item)
273 if item is None:
274 # Modules can be None. We don't have direct support for optional
275 # Modules, so the register it as an NoneType attribute instead.
276 concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:228, in infer_concrete_type_builder.<locals>.infer_type(name, item)
222 try:
223 if (
224 name in class_annotations
225 and class_annotations[name]
226 != torch.nn.Module.__annotations__["forward"]
227 ):
--> 228 ann_to_type = torch.jit.annotations.ann_to_type(
229 class_annotations[name], fake_range()
230 )
231 attr_type = torch._C.InferredType(ann_to_type)
232 elif isinstance(item, torch.jit.Attribute):
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/annotations.py:514, in ann_to_type(ann, loc, rcb)
512 if the_type is not None:
513 return the_type
--> 514 raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
ValueError: Unknown type annotation: '<class 'stable_baselines3.common.torch_layers.BaseFeaturesExtractor'>' at
```
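The traceback shows where scripting stops: while building the module's concrete type, `infer_type` looks up the class-level annotation of a submodule, and `ann_to_type` rejects `BaseFeaturesExtractor`, which is an `nn.Module` subclass that TorchScript does not seem to accept as an attribute annotation. The sketch below tries to reproduce this in plain PyTorch, assuming that annotation is the trigger; `Extractor`, `AnnotatedPolicy`, `PlainPolicy` and `try_script` are hypothetical names, not Stable-Baselines3 classes.

```python
import torch
import torch.nn as nn


class Extractor(nn.Module):
    """Hypothetical stand-in for a features extractor such as BaseFeaturesExtractor."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 8)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(obs))


class AnnotatedPolicy(nn.Module):
    # Class-level annotation naming an nn.Module subclass, mimicking the
    # annotation the traceback complains about.
    extractor: Extractor

    def __init__(self) -> None:
        super().__init__()
        self.extractor = Extractor()
        self.head = nn.Linear(8, 2)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.head(self.extractor(obs))


class PlainPolicy(nn.Module):
    # Same layers, but without the class-level annotation.
    def __init__(self) -> None:
        super().__init__()
        self.extractor = Extractor()
        self.head = nn.Linear(8, 2)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.head(self.extractor(obs))


def try_script(module: nn.Module) -> str:
    """Return "ok" if torch.jit.script succeeds, else the exception summary."""
    try:
        torch.jit.script(module)
        return "ok"
    except Exception as exc:  # the traceback above shows a ValueError
        return f"{type(exc).__name__}: {exc}"


annotated_result = try_script(AnnotatedPolicy())
plain_result = try_script(PlainPolicy())
print(annotated_result)
print(plain_result)
```

If the annotated module fails while the plain one scripts, the problem lies in the annotation on the policy class rather than in the layers themselves, which suggests trying `torch.jit.trace` instead.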
Is there a way to save the PPO model with `torch.jit.script` or `torch.jit.trace`? Thanks.
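On the `torch.jit.trace` route: tracing records the tensor operations executed on an example input instead of compiling the Python source, so the class annotation should not get in the way. A minimal sketch, assuming the policy can be called as `policy(obs, deterministic=True)` like SB3's `ActorCriticPolicy.forward`; `StandInPolicy` and `TraceablePolicy` are hypothetical names, and the stand-in only mimics the annotation and the `deterministic` flag, not SB3's internals. The wrapper exists to pin `deterministic=True`, because a traced graph only captures the branch taken during tracing.

```python
import io
from typing import Tuple

import torch
import torch.nn as nn


class Extractor(nn.Module):
    """Hypothetical stand-in for an SB3 features extractor."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 8)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(obs))


class StandInPolicy(nn.Module):
    """Hypothetical stand-in for model.policy: keeps the class-level annotation
    that breaks torch.jit.script and takes a Python bool `deterministic` flag."""

    extractor: Extractor

    def __init__(self) -> None:
        super().__init__()
        self.extractor = Extractor()
        self.action_net = nn.Linear(8, 2)
        self.value_net = nn.Linear(8, 1)

    def forward(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        features = self.extractor(obs)
        mean = self.action_net(features)
        # Add noise only in stochastic mode; tracing records just one branch.
        actions = mean if deterministic else mean + torch.randn_like(mean)
        return actions, self.value_net(features)


class TraceablePolicy(nn.Module):
    """Wrapper that fixes deterministic=True so the traced graph depends only on obs."""

    def __init__(self, policy: nn.Module) -> None:
        super().__init__()
        self.policy = policy

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.policy(obs, deterministic=True)


wrapper = TraceablePolicy(StandInPolicy()).eval()
dummy_obs = torch.zeros(1, 4)  # a batch containing one observation

# Record the operations run on dummy_obs.
traced = torch.jit.trace(wrapper, dummy_obs)

# Save to an in-memory buffer here; torch.jit.save(traced, "model.pt") writes a file.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

obs = torch.randn(3, 4)
with torch.no_grad():
    expected_actions, expected_values = wrapper(obs)
    loaded_actions, loaded_values = loaded(obs)
print(torch.allclose(expected_actions, loaded_actions))
```

For the real model, `TraceablePolicy(model.policy)` would replace the stand-in, with a dummy observation such as `th.zeros(1, *model.observation_space.shape)` for a `Box` observation space and the model loaded on CPU (e.g. `PPO.load(path, device="cpu")`) so the input and weights share a device. For PPO, `forward` returns `(actions, values, log_prob)`, so the traced module would return three tensors. Postprocessing done in `model.predict`, such as clipping continuous actions to the action-space bounds, is not part of `forward` and would have to be reapplied by the consumer. The Stable-Baselines3 documentation also appears to have an "Exporting models" page covering ONNX and TorchScript that is worth checking first.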
Checklist