Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tosa] Fix merge problems with mul shift #125129

Merged
merged 1 commit into from
Jan 31, 2025

Conversation

Tai78641
Copy link
Contributor

This patch fixes merge issues in TosaOpBase.td and TosaOps.td wrt traits on tosa elementwise ops and multiply op which, with the optional shift operand, is no longer strictly an elementwise op.

fixed up inferReturnTypeComponents to be based on only the first two operands (ie, ignoring shift, if present)

also fixed up TosaReduceTransposes to special handle tosa mul op now that it is not an elementwise op.

This patch fixes merge issues in TosaOpBase.td and TosaOps.td
wrt traits on tosa elementwise ops and multiply op
which, with the optional shift operand, is no longer strictly
an elementwise op.

fixed up inferReturnTypeComponents to be based on only the first
two operands (ie, ignoring shift, if present)

also fixed up TosaReduceTransposes to special handle tosa mul op
now that it is not an elementwise op.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I30bd10137870cfe079c761da21389ca6a3b2e5e8
@llvmbot
Copy link
Member

llvmbot commented Jan 30, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

This patch fixes merge issues in TosaOpBase.td and TosaOps.td wrt traits on tosa elementwise ops and multiply op which, with the optional shift operand, is no longer strictly an elementwise op.

fixed up inferReturnTypeComponents to be based on only the first two operands (ie, ignoring shift, if present)

also fixed up TosaReduceTransposes to special handle tosa mul op now that it is not an elementwise op.


Full diff: https://github.com/llvm/llvm-project/pull/125129.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+30-68)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+36-1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (+15-3)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 29afd6c27302cc..4975530a9588ca 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -239,7 +239,9 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
     Tosa_Op<mnemonic, !listconcat(traits, [
               DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                         ["inferReturnTypeComponents"]>,
+              ResultsBroadcastableShape,
               TosaElementwiseOperator,
+              SameOperandsAndResultRank,
               Pure])> {
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
@@ -248,8 +250,6 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
 class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
     Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
               SameOperandsAndResultShape,
-              ResultsBroadcastableShape,
-              SameOperandsAndResultRank,
               SameOperandsAndResultElementType])> {}
 
 class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9e3e41d288e4ac..c59c582a1f5221 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -482,9 +482,7 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
 //===----------------------------------------------------------------------===//
 def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsAndResultElementType]> {
   let summary = "Elementwise addition operator";
 
   let description = [{
@@ -517,10 +515,8 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
 //===----------------------------------------------------------------------===//
 // Operator: arithmetic_right_shift
 //===----------------------------------------------------------------------===//
-def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
+    [SameOperandsAndResultElementType]> {
   let summary = "Elementwise Arithmetic Right Shift";
 
   let description = [{
@@ -544,9 +540,7 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
 //===----------------------------------------------------------------------===//
 def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsAndResultElementType]> {
   let summary = "Bitwise AND operator";
 
   let description = [{
@@ -569,9 +563,7 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
 //===----------------------------------------------------------------------===//
 def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsAndResultElementType]> {
   let summary = "Bitwise OR operator";
 
   let description = [{
@@ -594,9 +586,7 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
 //===----------------------------------------------------------------------===//
 def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsAndResultElementType]> {
   let summary = "Bitwise XOR operator";
 
   let description = [{
@@ -617,10 +607,7 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
 //===----------------------------------------------------------------------===//
 // Operator: int_div
 //===----------------------------------------------------------------------===//
-def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
-    ResultsBroadcastableShape,
-    SameOperandsAndResultRank,
-    SameOperandsAndResultElementType]> {
+def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementType]> {
   let summary = "Integer divide operator";
 
   let description = [{
@@ -645,9 +632,7 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
 //===----------------------------------------------------------------------===//
 def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsAndResultElementType]> {
   let summary = "Returns the truth value of x AND y element-wise.";
 
   let description = [{
@@ -668,10 +653,8 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
 //===----------------------------------------------------------------------===//
 // Operator: logical_left_shift
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
+    [SameOperandsAndResultElementType]> {
   let summary = "Elementwise Logical Left Shift";
 
   let description = [{
@@ -692,10 +675,8 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
 //===----------------------------------------------------------------------===//
 // Operator: logical_right_shift
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
+    [SameOperandsAndResultElementType]> {
   let summary = "Elementwise Logical Right Shift";
 
   let description = [{
@@ -718,9 +699,7 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
 //===----------------------------------------------------------------------===//
 def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsAndResultElementType]> {
   let summary = "Returns the truth value of x OR y element-wise.";
 
   let description = [{
@@ -743,9 +722,7 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
 //===----------------------------------------------------------------------===//
 def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsAndResultElementType]> {
   let summary = "Returns the truth value of x XOR y element-wise.";
 
   let description = [{
@@ -768,9 +745,7 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
 //===----------------------------------------------------------------------===//
 def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsAndResultElementType]> {
   let summary = "Elementwise Maximum";
 
   let description = [{
@@ -794,9 +769,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
 //===----------------------------------------------------------------------===//
 def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsAndResultElementType]> {
   let summary = "Elementwise Minimum";
 
   let description = [{
@@ -823,9 +796,11 @@ def MulOperandsAndResultElementType :
 //===----------------------------------------------------------------------===//
 // Operator: mul
 //===----------------------------------------------------------------------===//
-def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
+def Tosa_MulOp : Tosa_Op<"mul", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
     Commutative,
-    MulOperandsAndResultElementType]> {
+    Pure]> {
   let summary = "Multiplication operator";
 
   let description = [{
@@ -846,15 +821,15 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: pow
 //===----------------------------------------------------------------------===//
-def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
   let summary = "Computes the power of one value to another.";
 
   let description = [{
@@ -875,10 +850,7 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
 //===----------------------------------------------------------------------===//
 // Operator: sub
 //===----------------------------------------------------------------------===//
-def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [
-    ResultsBroadcastableShape,
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
   let summary = "Elementwise subtraction operator";
 
   let description = [{
@@ -1229,9 +1201,7 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
 //===----------------------------------------------------------------------===//
 // Operator: select
 //===----------------------------------------------------------------------===//
-def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
-  ResultsBroadcastableShape,
-  SameOperandsAndResultRank]> {
+def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   let summary = "Elementwise select operator";
 
   let description = [{
@@ -1267,9 +1237,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
 def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
     InferTensorType,
     Commutative,
-    ResultsBroadcastableShape,
-    SameOperandsElementType,
-    SameOperandsAndResultRank]> {
+    SameOperandsElementType]> {
   let summary = "Returns the truth value of (x == y) element-wise.";
 
   let description = [{
@@ -1297,10 +1265,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
 //===----------------------------------------------------------------------===//
 // Operator: greater
 //===----------------------------------------------------------------------===//
-def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
-    ResultsBroadcastableShape,
-    SameOperandsElementType,
-    SameOperandsAndResultRank]> {
+def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
   let summary = "Returns the truth value of (x > y) element-wise.";
 
   let description = [{
@@ -1322,11 +1287,8 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
 //===----------------------------------------------------------------------===//
 // Operator: greater_equal
 //===----------------------------------------------------------------------===//
-def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", [
-    ResultsBroadcastableShape,
-     SameOperandsElementType,
-     SameOperandsAndResultRank
-    ]> {
+def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
+    [SameOperandsElementType]> {
   let summary = "Returns the truth value of (x >= y) element-wise.";
 
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c0b419b6f473c8..0a10439db40803 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -958,6 +958,28 @@ LogicalResult tosa::SliceOp::verify() {
   return success();
 }
 
+LogicalResult tosa::MulOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes,
+    OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  LogicalResult status = success();
+  llvm::SmallVector<int64_t> outShape;
+  if (operands.size() == 2) {
+    status = resolveBroadcastShape(operands, outShape);
+  } else {
+    // mul op's output shape only depend on input1 and input2, not on shift
+    ValueShapeRange two_inputs = operands.drop_back();
+    status = resolveBroadcastShape(two_inputs, outShape);
+  }
+  if (status.failed()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+  } else {
+    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  }
+  return success();
+}
+
 LogicalResult tosa::MulOp::verify() {
   auto resElemType = getElementTypeOrSelf(getOutput());
 
@@ -1030,6 +1052,20 @@ LogicalResult tosa::MulOp::verify() {
     }
   }
 
+  // check for broadcast compatible shapes in first two operands (ignoring
+  // shift)
+
+  // delegate function that returns shape of shaped type
+  auto getShape = [](const Type type) {
+    return mlir::cast<ShapedType>(type).getShape();
+  };
+  SmallVector<int64_t> resultShape;
+  if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
+                                                getShape(rankedOperandTypes[1]),
+                                                resultShape)) {
+    return emitOpError("operands don't have broadcast-compatible shapes");
+  }
+
   return success();
 }
 
@@ -1670,7 +1706,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
 NARY_SHAPE_INFER(tosa::LogicalXorOp)
 NARY_SHAPE_INFER(tosa::MaximumOp)
 NARY_SHAPE_INFER(tosa::MinimumOp)
-NARY_SHAPE_INFER(tosa::MulOp)
 NARY_SHAPE_INFER(tosa::NegateOp)
 NARY_SHAPE_INFER(tosa::PowOp)
 NARY_SHAPE_INFER(tosa::ReciprocalOp)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index e1f0a9592e8b4f..520f283a3ba888 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -281,13 +281,20 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
   if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
       !llvm::isa<tosa::ConstOp>(op)) {
 
-    if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+    if (!llvm::isa<tosa::MulOp>(op) &&
+        !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
       return false;
 
-    for (Value operand : op->getOperands())
+    for (Value operand : op->getOperands()) {
       // If this is a problem in future, think about alternatives to recursion.
+      if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
+          operand == op->getOperand(2)) {
+        // do not recurse into MulOp's shift operand
+        continue;
+      }
       if (!collectFanIn(operand.getDefiningOp(), collected))
         return false;
+    }
   }
 
   // Insert in topological order.
@@ -316,7 +323,8 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
     Operation *op, const DenseMap<Value, Value> &valuesMap,
     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
   if (op->getNumResults() != 1 ||
-      !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+      (!llvm::isa<tosa::MulOp>(op) &&
+       !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
     return std::nullopt;
 
   auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
@@ -324,6 +332,10 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
   for (Value v : op->getOperands()) {
     if (valuesMap.contains(v)) {
       operands.push_back(valuesMap.at(v));
+    } else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
+               v == op->getOperand(2)) {
+      // special case for MulOp's shift operand
+      operands.push_back(v);
     } else {
       return std::nullopt;
     }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 51d7f828510613..ac4d466aef94b2 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -183,7 +183,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> t
 // -----
 
 func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
-  %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>  
+  %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
   // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
   %1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
   return
@@ -211,7 +211,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
 // -----
 
 func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>  
+  %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
   %1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
   return %1 : tensor<13x21x3xf32>
@@ -749,7 +749,7 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1
 
 // CHECK-LABEL: test_mul_missing_shift
 func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
-  // expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
+  // this is ok because mul's shift operand is optional for now
   %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
   return %0 : tensor<13x21x3xi32>
 }
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 9eba2f7e5a06e4..a4596c8f9d5362 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -327,6 +327,14 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: test_mul_scalar_with_unranked_output
+func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {

@Jerry-Ge Jerry-Ge merged commit 79df1c3 into llvm:main Jan 31, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants