Skip to content

Commit

Permalink
8.0 Release (#2342)
Browse files Browse the repository at this point in the history
  • Loading branch information
junpeiz authored Sep 16, 2024
1 parent 40f6705 commit 7b13371
Show file tree
Hide file tree
Showing 55 changed files with 4,797 additions and 1,857 deletions.
3 changes: 2 additions & 1 deletion coremlpython/CoreMLPython.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ namespace CoreML {
Model(const Model&) = delete;
Model& operator=(const Model&) = delete;
~Model();
explicit Model(const std::string& urlStr, const std::string& computeUnits, const std::string& functionName);
explicit Model(const std::string& urlStr, const std::string& computeUnits, const std::string& functionName, const py::dict& optimizationHints);
explicit Model(MLModel* m_model, NSURL* compiledUrl, bool deleteCompiledModelOnExit);

py::list batchPredict(const py::list& batch) const;
Expand All @@ -67,6 +67,7 @@ namespace CoreML {
py::dict predict(const py::dict& input, State* state=NULL) const;

#if BUILT_WITH_MACOS15_SDK
static void setOptimizationHints(MLModelConfiguration *configuration, const py::dict& optimizationHints);
State newState() const;
#endif

Expand Down
44 changes: 42 additions & 2 deletions coremlpython/CoreMLPython.mm
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ bool usingMacOS13OrHigher() {
}
}

Model::Model(const std::string& urlStr, const std::string& computeUnits, const std::string& functionName) {
Model::Model(
const std::string& urlStr,
const std::string& computeUnits,
const std::string& functionName,
const py::dict& optimizationHints
) {
@autoreleasepool {
NSError *error = nil;

Expand Down Expand Up @@ -80,6 +85,10 @@ bool usingMacOS13OrHigher() {
MLModelConfiguration *configuration = [MLModelConfiguration new];
setComputeUnit(configuration, computeUnits);

#if BUILT_WITH_MACOS15_SDK
setOptimizationHints(configuration, optimizationHints);
#endif

if (!functionName.empty()) {
#if BUILT_WITH_MACOS15_SDK
configuration.functionName = [NSString stringWithUTF8String:functionName.c_str()];
Expand Down Expand Up @@ -148,6 +157,37 @@ bool usingMacOS13OrHigher() {
}


#if BUILT_WITH_MACOS15_SDK
void Model::setOptimizationHints(MLModelConfiguration *configuration, const py::dict& optimizationHints) {
// This function does minimal validation. It assumes Python layer has already validated.

// Reshape frequency optimization hint
if (optimizationHints.contains("reshapeFrequency")) {
const std::string val = optimizationHints["reshapeFrequency"].cast<std::string>();
if (val == "Frequent") {
configuration.optimizationHints.reshapeFrequency = MLReshapeFrequencyHintFrequent;
} else {
assert(val == "Infrequent");
configuration.optimizationHints.reshapeFrequency = MLReshapeFrequencyHintInfrequent;
}
}

// Specialization strategy optimization hint
if (optimizationHints.contains("specializationStrategy")) {
const std::string val = optimizationHints["specializationStrategy"].cast<std::string>();
if (val == "Default") {
configuration.optimizationHints.specializationStrategy = MLSpecializationStrategyDefault;
} else {
assert(val == "FastPrediction");
configuration.optimizationHints.specializationStrategy = MLSpecializationStrategyFastPrediction;
}
}


}
#endif


py::list Model::batchPredict(const py::list& batch) const {
@autoreleasepool {
NSError* error = nil;
Expand Down Expand Up @@ -237,7 +277,7 @@ bool usingMacOS13OrHigher() {
py::module m("libcoremlpython", "CoreML.Framework Python bindings");

py::class_<Model>(m, "_MLModelProxy")
.def(py::init<const std::string&, const std::string&, const std::string&>())
.def(py::init<const std::string&, const std::string&, const std::string&, const py::dict&>())
.def("predict", &Model::predict)
.def("batchPredict", &Model::batchPredict)
.def("get_compiled_model_path", &Model::getCompiledModelPath)
Expand Down
34 changes: 29 additions & 5 deletions coremltools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,35 @@ class ComputeUnit(_Enum):
'''
The set of processing-unit configurations the model can use to make predictions.
'''
ALL = 1 # Allows the model to use all compute units available, including the neural engine
CPU_AND_GPU = 2 # Allows the model to use both the CPU and GPU, but not the neural engine
CPU_ONLY = 3 # Limit the model to only use the CPU
CPU_AND_NE = 4 # Allows the model to use both the CPU and neural engine, but not the GPU.
# Only available on macOS >= 13.0
ALL = 1 # Allows model to use all compute units available, including the neural engine.
CPU_AND_GPU = 2 # Allows model to use both the CPU and GPU, but not the neural engine.
CPU_ONLY = 3 # Limits model to only use the CPU.
CPU_AND_NE = 4 # Allows model to use both the CPU and neural engine, but not the GPU.
# Only available on macOS >= 13.0


class ReshapeFrequency(_Enum):
'''
https://developer.apple.com/documentation/coreml/mlreshapefrequencyhint?language=objc
'''
Frequent = 1
Infrequent = 2


class SpecializationStrategy(_Enum):
'''
The optimization strategy for the model specialization.
https://developer.apple.com/documentation/coreml/mlspecializationstrategy?language=objc
'''

# The strategy that works well for most applications.
Default = 1

# Prefer the prediction latency at the potential cost of specialization time, memory footprint,
# and the disk space usage of specialized artifacts.
FastPrediction = 2


# A dictionary that maps the CoreML model specification version to the MLProgram/MIL opset string
_OPSET = {
Expand Down
26 changes: 24 additions & 2 deletions coremltools/_deps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,33 @@ def __get_sklearn_version(version):

# ---------------------------------------------------------------------------------------
_HAS_TORCH = True
_TORCH_MAX_VERSION = "2.3.0"
_TORCH_MAX_VERSION = "2.4.0"
_HAS_TORCH_EXPORT_API = False
_CT_OPTIMIZE_TORCH_MIN_VERSION = "2.1.0"
_IMPORT_CT_OPTIMIZE_TORCH = False
try:
import torch
_warn_if_above_max_supported_version("Torch", torch.__version__, _TORCH_MAX_VERSION)

if _get_version(torch.__version__) >= Version("2.1.0"):
torch_version = _get_version(torch.__version__)

if torch_version >= Version("2.1.0"):
_HAS_TORCH_EXPORT_API = True

if torch_version >= Version(_CT_OPTIMIZE_TORCH_MIN_VERSION):
_IMPORT_CT_OPTIMIZE_TORCH = True
else:
logger.warning(
(
f"Minimum required torch version for importing coremltools.optimize.torch is {_CT_OPTIMIZE_TORCH_MIN_VERSION}. "
f"Got torch version {torch_version}."
)
)

except:
_HAS_TORCH = False
MSG_TORCH_NOT_FOUND = "PyTorch not found."
MSG_TORCH_EXPORT_API_NOT_FOUND = "Torch.Export API not found."


_HAS_TORCH_VISION = True
Expand All @@ -189,6 +204,13 @@ def __get_sklearn_version(version):
_HAS_EXECUTORCH = False
MSG_EXECUTORCH_NOT_FOUND = "Executorch not found."

_HAS_TORCHAO = True
try:
import torchao
except:
_HAS_TORCHAO = False
MSG_TORCHAO_NOT_FOUND = "Torchao not found."

# ---------------------------------------------------------------------------------------
try:
import scipy
Expand Down
27 changes: 22 additions & 5 deletions coremltools/converters/mil/frontend/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,13 @@ def _concat_dims(dims, none_if_empty=False):


def _decompose_scaled_dot_product_attention(
q: Var, k: Var, v: Var, mask: Var, name: str, before_op: Optional[Operation] = None
q: Var,
k: Var,
v: Var,
mask: Var,
name: str,
scale: Optional[Var] = None,
before_op: Optional[Operation] = None,
) -> Var:
# scale the query input
embed_size = q.shape[-1]
Expand All @@ -524,9 +530,12 @@ def _decompose_scaled_dot_product_attention(
)

q, k, v = promote_input_dtypes([q, k, v])
multiplicative_scale_factor = 1 / math.sqrt(embed_size)
if types.builtin_to_string(q.dtype) == "fp16":
multiplicative_scale_factor = np.float16(multiplicative_scale_factor)
if scale is None:
multiplicative_scale_factor = 1 / math.sqrt(embed_size)
if types.builtin_to_string(q.dtype) == "fp16":
multiplicative_scale_factor = np.float16(multiplicative_scale_factor)
else:
multiplicative_scale_factor = scale
q = mb.mul(x=q, y=multiplicative_scale_factor, before_op=before_op)

# multiply query and key input tensors
Expand Down Expand Up @@ -583,6 +592,11 @@ def _construct_constexpr_dequant_op(
scale = np.squeeze(scale)
if isinstance(zero_point, (np.ndarray, np.generic)):
zero_point = np.squeeze(zero_point)
if len(scale.shape) > 1 or len(zero_point.shape) > 1:
raise ValueError(
"The more fine-grained quantization (such as blockwise) is only supported since iOS18."
"Please set minimum_deployment_target to iOS18 for using it."
)

kwargs = {
"quantized_data": quantized_weights,
Expand Down Expand Up @@ -631,7 +645,10 @@ def _construct_constexpr_dequant_op(
}
if zero_point is not None and np.any(zero_point):
# Only pass the offset parameter when not all elements in `zero_point` are zeroes.
zero_point = zero_point.reshape(scale.shape).astype(quantized_weights.dtype)
zero_point = zero_point.reshape(scale.shape)
# When zero_point is integer, it's required to have the same dtype as the quantized weight.
if np.issubdtype(zero_point.dtype, np.integer):
zero_point = zero_point.astype(quantized_weights.dtype)
kwargs["offset"] = zero_point
if name is not None:
kwargs["name"] = name
Expand Down
46 changes: 8 additions & 38 deletions coremltools/converters/mil/frontend/tensorflow/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2622,15 +2622,6 @@ def test_ios17_resize_bilinear_dynamic_shape(
target_shape,
align_corners,
):
if (
backend == ("mlprogram", "fp16")
and input_shape == (2, 5, 2, 3)
and target_shape == (20, 60)
):
pytest.xfail(
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
)

"""
Since iOS17, dynamic shape is supported by lowering to `resize` MIL op.
"""
Expand Down Expand Up @@ -2732,15 +2723,6 @@ def test_ios17_resize_nearest_neighbor_dynamic_shape(
input_shape,
target_shape,
):
if (
backend == ("mlprogram", "fp16")
and input_shape == (2, 5, 2, 3)
and target_shape == (20, 60)
):
pytest.xfail(
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
)

"""
Since iOS17, dynamic shape is supported by lowering to `resize` MIL op.
"""
Expand Down Expand Up @@ -5706,10 +5688,8 @@ def test_sort(self, compute_unit, backend, rank, dynamic):
"""
tf.sort dispatches to tf.math.top_k, and k = size of the axis to be sorted
"""
if backend[0] == "mlprogram" and dynamic:
pytest.xfail(
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
)
if platform.machine() == "x86_64" and dynamic:
pytest.xfail("rdar://135843153 ([Bug] Models failed on x86_64 platform)")

# Here we test the conversion of tf.sort(x, axis=0)
# If dynamic, we prepend None to x shape as the dynamic shape axis
Expand Down Expand Up @@ -6720,7 +6700,6 @@ def build_model(x):
def test_programmatic(
self, compute_unit, backend, input_block_rank, dynamic_input, dynamic_paddings
):

input_rank, block_rank = input_block_rank

# generate data
Expand All @@ -6733,6 +6712,9 @@ def test_programmatic(
if block_shape[0] == 1:
pytest.skip("neuralnetwork backend doesn't support unity block shape.")

if input_block_rank == (4, 1) and dynamic_input and not dynamic_paddings:
pytest.xfail("rdar://133558007 shape deduction failure")

paddings = []
for i in range(block_rank):
while True:
Expand Down Expand Up @@ -6832,14 +6814,12 @@ def test_programmatic(
self, compute_unit, backend, input_block_rank, dynamic_input, dynamic_crops
):
if (
backend == ("mlprogram", "fp16")
and input_block_rank == (3, 1) or (3,2)
platform.machine() == "x86_64"
and input_block_rank == (3, 1)
and dynamic_input
and not dynamic_crops
):
pytest.xfail(
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
)
pytest.xfail("rdar://135843153 ([Bug] Models failed on x86_64 platform)")

input_rank, block_rank = input_block_rank

Expand Down Expand Up @@ -6939,16 +6919,6 @@ def test_smoke_new_op(
input_shape, block_shape, crops = shape_block_crops
crops = np.array(crops, dtype=np.int32)

if (
backend == ("mlprogram", "fp16")
and shape_block_crops == [(4, 4, 6, 1), [1, 2], [[2, 1], [3, 3]]]
and dynamic_input
and not dynamic_crops
):
pytest.xfail(
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
)

# The neuralnetwork backend doesn't support these tests
if backend[0] == "neuralnetwork":
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1389,11 +1389,6 @@ def test_lstm_time_distributed_dense(self, compute_unit, backend):
"compute_unit, backend", itertools.product(compute_units, backends)
)
def test_lstm_dynamic_batch(self, compute_unit, backend):
if backend == ("mlprogram", "fp16"):
pytest.xfail(
"rdar://116060011: re-activate coremltools tests blocked by Core ML regressions"
)

input_shape = (1, 1280)
inp = tf.keras.layers.Input(shape=input_shape)
out, hn, cn = tf.keras.layers.LSTM(512,
Expand Down
Loading

0 comments on commit 7b13371

Please sign in to comment.