Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tosa] Change the shift of mul to be required #125297

Merged
merged 1 commit into from
Feb 11, 2025

Conversation

Tai78641
Copy link
Contributor

Change the shift operand for the mul operator to be a required operand.

Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor
whose shape is [1] (ie, tensor containing a single element)

@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

Change the shift operand for the mul operator to be a required operand.

Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor
whose shape is [1] (ie, tensor containing a single element)


Patch is 29.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125297.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+12)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+13-8)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+12-9)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (+2-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+4-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+9-5)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+10-5)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+35-7)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+2-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+18-16)
  • (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+6-4)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c59c582a1f5221..13e4376de8aa96 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -812,7 +812,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift
+    // Apply right shift on i32_t input data only
+    Tosa_ScalarInt8Tensor:$shift
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5693acf3a01db4..d02bf1589f44b0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[
     IsRankedTensorTypePred,
     CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
 
+def AllDimensionsAreSizeOne : And<[
+    IsRankedTensorTypePred,
+    CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
+
 class TosaTensorOf<
     list<Type> allowedTypes, string summary = "tosa-conformant tensor">
     : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -109,6 +113,11 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
       [HasAnyRankOfPred<ranks>],
       !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
 
+class TosaScalarTensorOf<list<Type> allowedTypes, list<int> ranks>
+    : TosaRankedTensorOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>, AllDimensionsAreSizeOne],
+      "tosa-conformant scalar tensor">;
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//
@@ -139,6 +148,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
 // Rank-0 (scalar) tensor
 def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
 
+// Scalar tensors: Rank-1 (with only one element)
+def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
+
 // We include unranked tensors as a supported type for all possible tosa
 // Tensors as unranked does not guarantee invalid. If unranked tensors exist
 // they should be shape propagate used Tosa's shape inference pass and verified
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..a0dfee80360688 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -92,22 +92,27 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   // tosa::MulOp
   if (isa<tosa::MulOp>(op)) {
     auto shift_val = cast<tosa::MulOp>(op).getShift();
+    ElementsAttr shift_elem;
+    if (!shift_val.getImpl() ||
+        !matchPattern(shift_val, m_Constant(&shift_elem))) {
+      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
+    }
+
+    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
 
     if (isa<FloatType>(elementTy)) {
+      if (shift != 0) {
+        (void)rewriter.notifyMatchFailure(op,
+                                          "Cannot have shift value for float");
+        return nullptr;
+      }
       return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
     }
 
     if (isa<IntegerType>(elementTy)) {
-      int32_t shift = 0;
-      ElementsAttr shift_elem;
-      if (shift_val.getImpl() &&
-          matchPattern(shift_val, m_Constant(&shift_elem))) {
-        // Explicit shift is set.
-        shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
-      }
-
       Value a = args[0];
       Value b = args[1];
+
       if (shift > 0) {
         auto shiftConst =
             rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0a10439db40803..43470a81cd57ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -963,16 +963,10 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes,
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  LogicalResult status = success();
+  // mul op's output shape only depend on input1 and input2, not on shift
+  ValueShapeRange twoInputs = operands.drop_back();
   llvm::SmallVector<int64_t> outShape;
-  if (operands.size() == 2) {
-    status = resolveBroadcastShape(operands, outShape);
-  } else {
-    // mul op's output shape only depend on input1 and input2, not on shift
-    ValueShapeRange two_inputs = operands.drop_back();
-    status = resolveBroadcastShape(two_inputs, outShape);
-  }
-  if (status.failed()) {
+  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
   } else {
     inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -1007,6 +1001,15 @@ LogicalResult tosa::MulOp::verify() {
         return emitOpError(
             "requires the same element type for all operands and results");
     }
+
+    // verify shift has value 0 for non-integer types
+    ElementsAttr shift_elem;
+    if (matchPattern(getShift(), m_Constant(&shift_elem))) {
+      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+      if (shift != 0) {
+        return emitOpError() << "require shift to be 0 for float type";
+      }
+    }
   }
 
   // Verify the op has same ranks for all main operands (excludes extra operands
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 520f283a3ba888..4c312ffd124e24 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -287,8 +287,7 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
 
     for (Value operand : op->getOperands()) {
       // If this is a problem in future, think about alternatives to recursion.
-      if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
-          operand == op->getOperand(2)) {
+      if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
         // do not recurse into MulOp's shift operand
         continue;
       }
@@ -332,8 +331,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
   for (Value v : op->getOperands()) {
     if (valuesMap.contains(v)) {
       operands.push_back(valuesMap.at(v));
-    } else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
-               v == op->getOperand(2)) {
+    } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
       // special case for MulOp's shift operand
       operands.push_back(v);
     } else {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317a..3704b4c29fceaf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -472,7 +472,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: arith.mulf
-  %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
@@ -618,7 +619,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
   // CHECK: arith.extsi
   // CHECK: arith.extsi
   // CHECK: arith.muli
-  %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1xi32>
 
   return
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63cc..a9895dd45d62bd 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -331,8 +331,9 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
 func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -343,7 +344,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -379,11 +381,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
   // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
   // CHECK-NOT: tosa.mul
   %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32>
 
   // CHECK-NOT: tosa.mul
   // CHECK: return %[[ZERO]], %[[ZERO]]
-  %2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
   return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
 }
 
@@ -966,7 +969,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
    // CHECK: tosa.mul
    %0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    %1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   %2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>)-> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+   %2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi8>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 32677f06e22523..89c17fa1ab5c83 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -238,7 +238,8 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
 func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %arg0, %zero, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -249,7 +250,8 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %zero, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -283,7 +285,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 // CHECK-LABEL: @fold_mul_one_rhs_f32
 func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %mul = tosa.mul %arg0, %one : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %arg0, %one, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
   // CHECK: return %arg0
   return %mul : tensor<f32>
 }
@@ -293,7 +296,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @fold_mul_one_lhs_f32
 func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %mul = tosa.mul %one, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %one, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
   // CHECK: return %arg0
   return %mul : tensor<f32>
 }
@@ -339,7 +343,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> {
 func.func @fold_mul_splat_f32() -> tensor<10xf32> {
   %one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
   %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %mul = tosa.mul %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32>
   // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>}
   // CHECK: return %[[THREE]]
   return %mul : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ac4d466aef94b2..5c1dbcac1bcb83 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -730,26 +730,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
 
 // CHECK-LABEL: test_mul_type_mismatch
 func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
-  %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf16>) -> tensor<13x21x3xf32>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: test_mul_invalid_shift
-func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
-  %shift = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
-  // expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>'}}
-  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<f32>) -> tensor<13x21x3xi32>
-  return %0 : tensor<13x21x3xi32>
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: test_mul_missing_shift
 func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
-  // this is ok because mul's shift operand is optional for now
+  // expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
   %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
   return %0 : tensor<13x21x3xi32>
 }
@@ -1061,3 +1062,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
   %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32>
   return %0 : tensor<1x13x21x3xf32>
 }
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_2d
+func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
+  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_1d
+func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
+  // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_broadcast
+func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4596c8f9d5362..2774a82d6fb8b5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -338,7 +338,8 @@ func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tenso
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 73eabab657f380..028105855ce25b 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -114,23 +114,24 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>)
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %2 = tosa.minimum %arg0, %arg1 ...
[truncated]

@@ -812,7 +812,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift
// Apply right shift on i32_t input data only
Tosa_ScalarInt8Tensor:$shift
Copy link
Contributor

@GeorgeARM GeorgeARM Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why shall this be required at dialect level? Not in favor of a compulsory parameter when it doesn't make sense. Any good reasoning behind this?

@@ -139,6 +148,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
// Rank-0 (scalar) tensor
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;

// Scalar tensors: Rank-1 (with only one element)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scalar tensors have rank0. This contradicts the above definition. Maybe worth of a different name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed to Tosa_Rank0Tensor

@Jerry-Ge
Copy link
Member

Hi @Tai78641 , can you do the rebase.

Set the shift to be a mandatory operand.

Change-Id: Ic8f6e2f653c6f5875f8b789a2da27d2834aa442a
Signed-off-by: Tai Ly <[email protected]>
@Tai78641
Copy link
Contributor Author

Hi @Tai78641 , can you do the rebase.

rebased

@Jerry-Ge
Copy link
Member

Approved, but let @GeorgeARM to take another look to double-confirm.

@Jerry-Ge Jerry-Ge merged commit 20ae283 into llvm:main Feb 11, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
Change the shift operand for the mul operator to be a required operand.

Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is
a rank-1 tensor
whose shape is [1] (ie, tensor containing a single element)

Signed-off-by: Tai Ly <[email protected]>
alejandro-alvarez-sonarsource pushed a commit to alejandro-alvarez-sonarsource/llvm-project that referenced this pull request Feb 12, 2025
Change the shift operand for the mul operator to be a required operand.

Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is
a rank-1 tensor
whose shape is [1] (ie, tensor containing a single element)

Signed-off-by: Tai Ly <[email protected]>
flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
Change the shift operand for the mul operator to be a required operand.

Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is
a rank-1 tensor
whose shape is [1] (ie, tensor containing a single element)

Signed-off-by: Tai Ly <[email protected]>
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
Change the shift operand for the mul operator to be a required operand.

Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is
a rank-1 tensor
whose shape is [1] (ie, tensor containing a single element)

Signed-off-by: Tai Ly <[email protected]>
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
Change the shift operand for the mul operator to be a required operand.

Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is
a rank-1 tensor
whose shape is [1] (ie, tensor containing a single element)

Signed-off-by: Tai Ly <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants