
Torch -> ONNX doesn't work after upgrading transformers to 4.49.0 #36276

Open

dongruliu opened this issue Feb 19, 2025 · 3 comments

@dongruliu
System Info

  • transformers version: 4.49.0
  • onnx version: 1.17.0
  • Platform: Linux-5.4.143.bsk.8-amd64-x86_64-with-glibc2.31
  • Python version: 3.9.21
  • Huggingface_hub version: 0.28.1
  • Safetensors version: 0.5.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: Tesla V100-SXM2-32GB

Who can help?

@amyeroberts
I am trying to convert the CLIP model to ONNX. It worked fine with transformers==4.35.2, but it fails with transformers==4.49.0. Is there anything I can do to fix this issue? Thanks!

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Code

from transformers import CLIPVisionModel
import torch

# Load the CLIP vision tower from the Hub
model = CLIPVisionModel.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
)

# Dummy input: one 224x224 RGB image
input = {"pixel_values": torch.randn([1, 3, 224, 224])}

# Export with the TorchScript-based exporter, making the batch dimension dynamic
torch.onnx.export(
    model=model,
    args=tuple(input.values()),
    f="./model.onnx",
    input_names=list(input.keys()),
    output_names=["output"],
    dynamic_axes={key: {0: "batch_size"} for key in input.keys()}
    | {"output": {0: "batch_size"}},
)

Error Message

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 10
      4 model = CLIPVisionModel.from_pretrained(
      5     "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
      6 )
      8 input = {"pixel_values": torch.randn([1, 3, 224, 224])}
---> 10 torch.onnx.export(
     11     model=model,
     12     args=tuple(input.values()),
     13     f="./model.onnx",
     14     input_names=list(input.keys()),
     15     output_names=["output"],
     16     dynamic_axes={key: {0: "batch_size"} for key in input.keys()}
     17     | {"output": {0: "batch_size"}},
     18 )

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/utils.py:516, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
    189 @_beartype.beartype
    190 def export(
    191     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    208     autograd_inlining: Optional[bool] = True,
    209 ) -> None:
    210     r"""Exports a model into ONNX format.
    211 
    212     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    513             All errors are subclasses of :class:`errors.OnnxExporterError`.
    514     """
--> 516     _export(
    517         model,
    518         args,
    519         f,
    520         export_params,
    521         verbose,
    522         training,
    523         input_names,
    524         output_names,
    525         operator_export_type=operator_export_type,
    526         opset_version=opset_version,
    527         do_constant_folding=do_constant_folding,
    528         dynamic_axes=dynamic_axes,
    529         keep_initializers_as_inputs=keep_initializers_as_inputs,
    530         custom_opsets=custom_opsets,
    531         export_modules_as_functions=export_modules_as_functions,
    532         autograd_inlining=autograd_inlining,
    533     )

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/utils.py:1612, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
   1609     dynamic_axes = {}
   1610 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> 1612 graph, params_dict, torch_out = _model_to_graph(
   1613     model,
   1614     args,
   1615     verbose,
   1616     input_names,
   1617     output_names,
   1618     operator_export_type,
   1619     val_do_constant_folding,
   1620     fixed_batch_size=fixed_batch_size,
   1621     training=training,
   1622     dynamic_axes=dynamic_axes,
   1623 )
   1625 # TODO: Don't allocate a in-memory string for the protobuf
   1626 defer_weight_export = (
   1627     export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
   1628 )

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/utils.py:1138, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1135 params_dict = _get_named_param_dict(graph, params)
   1137 try:
-> 1138     graph = _optimize_graph(
   1139         graph,
   1140         operator_export_type,
   1141         _disable_torch_constant_prop=_disable_torch_constant_prop,
   1142         fixed_batch_size=fixed_batch_size,
   1143         params_dict=params_dict,
   1144         dynamic_axes=dynamic_axes,
   1145         input_names=input_names,
   1146         module=module,
   1147     )
   1148 except Exception as e:
   1149     torch.onnx.log("Torch IR graph at exception: ", graph)

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/utils.py:677, in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
    674     _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
    675 _C._jit_pass_onnx_lint(graph)
--> 677 graph = _C._jit_pass_onnx(graph, operator_export_type)
    678 _C._jit_pass_onnx_lint(graph)
    679 _C._jit_pass_lint(graph)

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/utils.py:1956, in _run_symbolic_function(graph, block, node, inputs, env, operator_export_type)
   1951     if symbolic_fn is not None:
   1952         # TODO Wrap almost identical attrs assignment or comment the difference.
   1953         attrs = {
   1954             k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
   1955         }
-> 1956         return symbolic_fn(graph_context, *inputs, **attrs)
   1958 attrs = {
   1959     k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
   1960     for k in node.attributeNames()
   1961 }
   1962 if namespace == "onnx":
   1963     # Clone node to trigger ONNX shape inference

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/symbolic_helper.py:306, in parse_args.<locals>.decorator.<locals>.wrapper(g, *args, **kwargs)
    300 if len(kwargs) == 1:
    301     assert "_outputs" in kwargs, (
    302         f"Symbolic function {fn.__name__}'s '**kwargs' can only contain "
    303         f"'_outputs' key at '**kwargs'. "
    304         f"{FILE_BUG_MSG}"
    305     )
--> 306 return fn(g, *args, **kwargs)

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/symbolic_opset14.py:176, in scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale)
    172 key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)
    174 # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    175 # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
--> 176 query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
    177 key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
    178 mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/_internal/jit_utils.py:87, in GraphContext.op(self, opname, outputs, *raw_args, **kwargs)
     59 """Creates an ONNX operator "opname", taking "raw_args" as inputs and "kwargs" as attributes.
     60 
     61 The set of operators and the inputs/attributes they take
   (...)
     84     keyword argument for multi-return nodes).
     85 """
     86 # FIXME(justinchuby): Add the return type back once we know how to handle mypy
---> 87 return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/_internal/jit_utils.py:238, in _add_op(graph_context, opname, outputs, *args, **kwargs)
    199 @_beartype.beartype
    200 def _add_op(
    201     graph_context: GraphContext,
   (...)
    205     **kwargs,
    206 ):
    207     """Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".
    208 
    209     The set of operators and the inputs/attributes they take
   (...)
    236         keyword argument for multi-return nodes).
    237     """
--> 238     inputs = [_const_if_tensor(graph_context, arg) for arg in args]
    239     # Filter out None attributes, this can be convenient client side because
    240     # now they can pass through None attributes, and have them not show up
    241     attributes = {k: v for k, v in kwargs.items() if v is not None}

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/_internal/jit_utils.py:238, in <listcomp>(.0)
    199 @_beartype.beartype
    200 def _add_op(
    201     graph_context: GraphContext,
   (...)
    205     **kwargs,
    206 ):
    207     """Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".
    208 
    209     The set of operators and the inputs/attributes they take
   (...)
    236         keyword argument for multi-return nodes).
    237     """
--> 238     inputs = [_const_if_tensor(graph_context, arg) for arg in args]
    239     # Filter out None attributes, this can be convenient client side because
    240     # now they can pass through None attributes, and have them not show up
    241     attributes = {k: v for k, v in kwargs.items() if v is not None}

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/_internal/jit_utils.py:269, in _const_if_tensor(graph_context, arg)
    266 if isinstance(arg, _C.Value):
    267     return arg
--> 269 return _add_op(graph_context, "onnx::Constant", value_z=arg)

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/_internal/jit_utils.py:246, in _add_op(graph_context, opname, outputs, *args, **kwargs)
    243 if "::" not in opname:
    244     opname = "onnx::" + opname
--> 246 node = _create_node(
    247     graph_context.block,
    248     opname,
    249     inputs,
    250     attributes,
    251     params_dict=graph_context.params_dict,
    252     opset_version=graph_context.opset,
    253     n_outputs=outputs,
    254     shape_inference=GLOBALS.onnx_shape_inference,
    255 )
    257 if outputs == 1:
    258     return node.output()

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/_internal/jit_utils.py:305, in _create_node(graph_or_block, domain_op, inputs, attributes, params_dict, opset_version, n_outputs, shape_inference)
    303     if key in _SKIP_NODE_ATTRIBUTES:
    304         continue
--> 305     _add_attribute(node, key, value, aten=aten)
    306 if shape_inference:
    307     _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)

File /mnt/bn/dcc-aai/users/dongru/miniconda3/envs/model_inference/lib/python3.9/site-packages/torch/onnx/_internal/jit_utils.py:356, in _add_attribute(node, key, value, aten)
    354         else:
    355             kind = "i"
--> 356 return getattr(node, f"{kind}_")(name, value)

TypeError: z_(): incompatible function arguments. The following argument types are supported:
    1. (self: torch._C.Node, arg0: str, arg1: torch.Tensor) -> torch._C.Node

Invoked with: %287 : Tensor = onnx::Constant(), scope: transformers.models.clip.modeling_clip.CLIPVisionModel::/transformers.models.clip.modeling_clip.CLIPVisionTransformer::vision_model/transformers.models.clip.modeling_clip.CLIPEncoder::encoder/transformers.models.clip.modeling_clip.CLIPEncoderLayer::layers.0/transformers.models.clip.modeling_clip.CLIPSdpaAttention::self_attn
, 'value', 0.125 
(Occurred when translating scaled_dot_product_attention).

Expected behavior

The model should be successfully converted and stored.

@dongruliu dongruliu added the bug label Feb 19, 2025
@Rocketknight1
Member

Not sure who handles onnx! cc @muellerzr @SunMarc @MekkCyber

@rwightman
Contributor

rwightman commented Feb 19, 2025

It's probably SDPA... try setting attn_implementation="eager" when you create the model via from_pretrained. Also, I believe Optimum might have built-in handling for this, so using that for the export would probably help: https://github.com/huggingface/optimum
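A minimal sketch of that suggestion, reusing the reproduction snippet above and assuming only the standard attn_implementation argument of from_pretrained (everything else is unchanged):

from transformers import CLIPVisionModel
import torch

# Request the eager attention implementation so tracing does not go through
# torch.nn.functional.scaled_dot_product_attention, whose opset-14 symbolic
# fails on the Python-float `scale` (0.125) seen in the traceback above.
model = CLIPVisionModel.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
    attn_implementation="eager",
)

input = {"pixel_values": torch.randn([1, 3, 224, 224])}

# Same export call as in the reproduction, with a dynamic batch dimension
torch.onnx.export(
    model=model,
    args=tuple(input.values()),
    f="./model.onnx",
    input_names=list(input.keys()),
    output_names=["output"],
    dynamic_axes={key: {0: "batch_size"} for key in input.keys()}
    | {"output": {0: "batch_size"}},
)

Alternatively, Optimum's exporter (roughly `optimum-cli export onnx --model laion/CLIP-ViT-B-32-laion2B-s34B-b79K clip_onnx/`) applies its own export-time patches; the exact task selection for a bare vision tower may differ, so treat that command as a starting point rather than a verified recipe.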

@dongruliu
Author

It's probably SDPA... try setting attn_implementation="eager" when you create the model via from_pretrained. Also, I believe Optimum might have built-in handling for this, so using that for the export would probably help: https://github.com/huggingface/optimum

It works! Thanks!
