Exporting ct.convert crashes for torch.unfold (im2col) with flexible shapes #2446

Open
ndrnml opened this issue Feb 10, 2025 · 0 comments
Labels: bug, Flexible Shape, triaged


ndrnml commented Feb 10, 2025

Description

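Converting torch.nn.functional.unfold (traced as the im2col op) with ct.convert crashes when the input tensor has a flexible shape (ct.RangeDim height and width). The same model converts successfully with a fixed input shape.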

Stack Trace

ERROR - converting 'im2col' op (located at: '14'):

Converting PyTorch Frontend ==> MIL Ops:  92%|█████████▏| 12/13 [00:00<00:00, 563.92 ops/s]
Traceback (most recent call last):
  File "playground.py", line 375, in <module>
    coreml_model = ct.convert(
  File "/lib/python3.10/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
    mlmodel = mil_convert(
  File "/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 186, in mil_convert
    return _mil_convert(
  File "/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 218, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 294, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 106, in __call__
    return load(*args, **kwargs)
  File "/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 88, in load
    return _perform_torch_convert(converter, debug)
  File "/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 151, in _perform_torch_convert
    prog = converter.convert()
  File "/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1380, in convert
    convert_nodes(self.context, self.graph, early_exit=not has_states)
  File "/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 117, in convert_nodes
    raise e     # re-raise exception
  File "/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 112, in convert_nodes
    convert_single_node(context, node)
  File "/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 173, in convert_single_node
    add_op(context, node)
  File "/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 8142, in im2col
    indices = _construct_unfold_indices(N, C, H, W, kernel_size, stride)
  File "/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 8060, in _construct_unfold_indices
    offset_idx = np.arange(0, row_extent, stride[0])[None, :, None] * W + np.arange(0, col_extent, stride[1])
  File "/lib/python3.10/site-packages/sympy/core/expr.py", line 340, in __float__
    raise TypeError("Cannot convert expression to float")
TypeError: Cannot convert expression to float
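The failing frame is _construct_unfold_indices, which precomputes gather indices with np.arange(0, row_extent, stride[0]). With ct.RangeDim inputs, H and W are symbolic (sympy expressions), so NumPy cannot materialize those ranges and sympy's __float__ raises. The sketch below is a hypothetical NumPy reimplementation of index-based im2col for a single (C, H, W) sample. It is not the coremltools code, and unfold_indices and unfold_reference are illustrative names. It checks the gathered patches against a naive loop and shows that every index depends on concrete values of H and W.

```python
import numpy as np


def unfold_indices(C, H, W, kernel_size, stride):
    """Hypothetical helper: flat indices into a (C, H, W) array that gather
    im2col patches. Returns shape (C*kh*kw, L). Needs concrete integer H, W."""
    kh, kw = kernel_size
    sh, sw = stride
    out_h = (H - kh) // sh + 1  # number of window rows
    out_w = (W - kw) // sw + 1  # number of window columns
    # Offset of each element inside one window, ordered (c, i, j) like F.unfold
    c_off = np.arange(C)[:, None, None] * H * W
    k_off = np.arange(kh)[None, :, None] * W + np.arange(kw)[None, None, :]
    kernel_offsets = (c_off + k_off).reshape(-1)  # (C*kh*kw,)
    # Flat offset of the top-left corner of every window, row-major
    starts = (np.arange(out_h)[:, None] * sh * W
              + np.arange(out_w)[None, :] * sw).reshape(-1)  # (L,)
    return kernel_offsets[:, None] + starts[None, :]


def unfold_reference(x, kernel_size, stride):
    """Naive im2col on a (C, H, W) array, used only for comparison."""
    C, H, W = x.shape
    kh, kw = kernel_size
    sh, sw = stride
    cols = []
    for r in range(0, H - kh + 1, sh):
        for s in range(0, W - kw + 1, sw):
            cols.append(x[:, r:r + kh, s:s + kw].reshape(-1))
    return np.stack(cols, axis=1)


# Same sizes as the fixed-shape repro: C=1, 64x64 input, 3x3 kernel, stride 1
rng = np.random.default_rng(0)
x = rng.random((1, 64, 64))
patches = x.reshape(-1)[unfold_indices(1, 64, 64, (3, 3), (1, 1))]
print(patches.shape)  # (9, 3844)
```

Because out_h, out_w and the row stride W all enter the index array, a flexible-shape conversion would presumably need these indices computed at run time from the input shape rather than baked in as a constant.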

To Reproduce

Working example with a fixed input shape:

import torch
import coremltools as ct
from torch import nn
import torch.nn.functional as F


class UnfoldModule(nn.Module):
    def forward(self, x):
        # Extract 3x3 sliding patches (im2col): output shape is (N, C*3*3, L)
        return F.unfold(x, (3, 3), padding=0)

model = UnfoldModule()
x = torch.rand(1, 1, 64, 64)

inpt = (x,)
exported_model = torch.export.export(model, inpt)  # not used by the conversion below
traced_model = torch.jit.trace(model, inpt)

# Static input shape: conversion succeeds
ct.convert(traced_model, inputs=[ct.TensorType(shape=x.shape)]).save("/tmp/model.mlpackage")

Non-working example with a flexible input shape:

import torch
import coremltools as ct
from torch import nn
import torch.nn.functional as F


class UnfoldModule(nn.Module):
    def forward(self, x):
        # Extract 3x3 sliding patches (im2col): output shape is (N, C*3*3, L)
        return F.unfold(x, (3, 3), padding=0)

model = UnfoldModule()
x = torch.rand(1, 1, 64, 64)

inpt = (x,)
exported_model = torch.export.export(model, inpt)  # not used by the conversion below
traced_model = torch.jit.trace(model, inpt)

# H and W may each vary from 64 to 128 at run time
flexible_shape = ct.Shape(shape=(1, 1, ct.RangeDim(lower_bound=64, upper_bound=128, default=64),
                                 ct.RangeDim(lower_bound=64, upper_bound=128, default=64)))

# Fails in the im2col converter with "TypeError: Cannot convert expression to float"
ct.convert(traced_model, inputs=[ct.TensorType(shape=flexible_shape)]).save("/tmp/model.mlpackage")
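For reference, PyTorch documents the number of unfold blocks as L = product over spatial dims of floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride + 1). The helper below (unfold_output_shape is a hypothetical name) evaluates that formula for the RangeDim bounds used above. The last output dimension ranges from 62*62 = 3844 to 126*126 = 15876, and H and W vary independently, so the converter cannot treat it as a constant.

```python
def unfold_output_shape(n, c, h, w, kernel_size, stride=(1, 1),
                        padding=(0, 0), dilation=(1, 1)):
    """Shape of F.unfold(x, kernel_size, ...) for x of shape (n, c, h, w)."""
    blocks = 1
    for size, k, s, p, d in zip((h, w), kernel_size, stride, padding, dilation):
        # Floor division matches the floor in the documented formula
        blocks *= (size + 2 * p - d * (k - 1) - 1) // s + 1
    return (n, c * kernel_size[0] * kernel_size[1], blocks)


print(unfold_output_shape(1, 1, 64, 64, (3, 3)))    # RangeDim default: (1, 9, 3844)
print(unfold_output_shape(1, 1, 64, 128, (3, 3)))   # mixed sizes: (1, 9, 7812)
print(unfold_output_shape(1, 1, 128, 128, (3, 3)))  # RangeDim upper bound: (1, 9, 15876)
```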

System environment

Ubuntu 22.04.5 LTS
Python 3.10.12
torch 2.3.1
torchvision 0.18.1
coremltools 8.2

ndrnml added the bug label on Feb 10, 2025.
ndrnml changed the title from "Unfold (im2col) can't handle flexible shapes" to "Exporting ct.convert crashes for torch.unfold (im2col) with flexible shapes" on Feb 10, 2025.
TobyRoseman added the triaged and Flexible Shape labels on Feb 11, 2025.