-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tosa] Change the shift of mul to be required #125297
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesChange the shift operand for the mul operator to be a required operand. Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor Patch is 29.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125297.diff 12 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c59c582a1f5221..13e4376de8aa96 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -812,7 +812,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift
+ // Apply right shift on i32_t input data only
+ Tosa_ScalarInt8Tensor:$shift
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5693acf3a01db4..d02bf1589f44b0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
+def AllDimensionsAreSizeOne : And<[
+ IsRankedTensorTypePred,
+ CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
+
class TosaTensorOf<
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -109,6 +113,11 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
+class TosaScalarTensorOf<list<Type> allowedTypes, list<int> ranks>
+ : TosaRankedTensorOf<allowedTypes,
+ [HasAnyRankOfPred<ranks>, AllDimensionsAreSizeOne],
+ "tosa-conformant scalar tensor">;
+
//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//
@@ -139,6 +148,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
// Rank-0 (scalar) tensor
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
+// Scalar tensors: Rank-1 (with only one element)
+def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
+
// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
// they should be shape propagate used Tosa's shape inference pass and verified
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..a0dfee80360688 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -92,22 +92,27 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MulOp
if (isa<tosa::MulOp>(op)) {
auto shift_val = cast<tosa::MulOp>(op).getShift();
+ ElementsAttr shift_elem;
+ if (!shift_val.getImpl() ||
+ !matchPattern(shift_val, m_Constant(&shift_elem))) {
+ (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
+ }
+
+ int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
if (isa<FloatType>(elementTy)) {
+ if (shift != 0) {
+ (void)rewriter.notifyMatchFailure(op,
+ "Cannot have shift value for float");
+ return nullptr;
+ }
return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
}
if (isa<IntegerType>(elementTy)) {
- int32_t shift = 0;
- ElementsAttr shift_elem;
- if (shift_val.getImpl() &&
- matchPattern(shift_val, m_Constant(&shift_elem))) {
- // Explicit shift is set.
- shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
- }
-
Value a = args[0];
Value b = args[1];
+
if (shift > 0) {
auto shiftConst =
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0a10439db40803..43470a81cd57ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -963,16 +963,10 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- LogicalResult status = success();
+ // mul op's output shape only depend on input1 and input2, not on shift
+ ValueShapeRange twoInputs = operands.drop_back();
llvm::SmallVector<int64_t> outShape;
- if (operands.size() == 2) {
- status = resolveBroadcastShape(operands, outShape);
- } else {
- // mul op's output shape only depend on input1 and input2, not on shift
- ValueShapeRange two_inputs = operands.drop_back();
- status = resolveBroadcastShape(two_inputs, outShape);
- }
- if (status.failed()) {
+ if (resolveBroadcastShape(twoInputs, outShape).failed()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
} else {
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -1007,6 +1001,15 @@ LogicalResult tosa::MulOp::verify() {
return emitOpError(
"requires the same element type for all operands and results");
}
+
+ // verify shift has value 0 for non-integer types
+ ElementsAttr shift_elem;
+ if (matchPattern(getShift(), m_Constant(&shift_elem))) {
+ int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ if (shift != 0) {
+ return emitOpError() << "require shift to be 0 for float type";
+ }
+ }
}
// Verify the op has same ranks for all main operands (excludes extra operands
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 520f283a3ba888..4c312ffd124e24 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -287,8 +287,7 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
for (Value operand : op->getOperands()) {
// If this is a problem in future, think about alternatives to recursion.
- if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
- operand == op->getOperand(2)) {
+ if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
// do not recurse into MulOp's shift operand
continue;
}
@@ -332,8 +331,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
for (Value v : op->getOperands()) {
if (valuesMap.contains(v)) {
operands.push_back(valuesMap.at(v));
- } else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
- v == op->getOperand(2)) {
+ } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
// special case for MulOp's shift operand
operands.push_back(v);
} else {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317a..3704b4c29fceaf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -472,7 +472,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: arith.mulf
- %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: arith.negf
@@ -618,7 +619,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
// CHECK: arith.extsi
// CHECK: arith.extsi
// CHECK: arith.muli
- %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63cc..a9895dd45d62bd 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -331,8 +331,9 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}
@@ -343,7 +344,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
- %1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}
@@ -379,11 +381,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
// CHECK-NOT: tosa.mul
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
- %1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32>
// CHECK-NOT: tosa.mul
// CHECK: return %[[ZERO]], %[[ZERO]]
- %2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+ %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
}
@@ -966,7 +969,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
// CHECK: tosa.mul
%0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
%1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- %2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>)-> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi8>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 32677f06e22523..89c17fa1ab5c83 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -238,7 +238,8 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
- %mul = tosa.mul %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %arg0, %zero, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
@@ -249,7 +250,8 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
- %mul = tosa.mul %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %zero, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
@@ -283,7 +285,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-LABEL: @fold_mul_one_rhs_f32
func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
- %mul = tosa.mul %arg0, %one : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %arg0, %one, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %arg0
return %mul : tensor<f32>
}
@@ -293,7 +296,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @fold_mul_one_lhs_f32
func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
- %mul = tosa.mul %one, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %one, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %arg0
return %mul : tensor<f32>
}
@@ -339,7 +343,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> {
func.func @fold_mul_splat_f32() -> tensor<10xf32> {
%one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
%two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
- %mul = tosa.mul %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32>
// CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>}
// CHECK: return %[[THREE]]
return %mul : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ac4d466aef94b2..5c1dbcac1bcb83 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -730,26 +730,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
// CHECK-LABEL: test_mul_type_mismatch
func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
- %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf16>) -> tensor<13x21x3xf32>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
// -----
// CHECK-LABEL: test_mul_invalid_shift
-func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
- %shift = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
- // expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>'}}
- %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<f32>) -> tensor<13x21x3xi32>
- return %0 : tensor<13x21x3xi32>
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
}
// -----
// CHECK-LABEL: test_mul_missing_shift
func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
- // this is ok because mul's shift operand is optional for now
+ // expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -1061,3 +1062,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
%0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32>
return %0 : tensor<1x13x21x3xf32>
}
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_2d
+func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_scalar_shift_1d
+func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{value = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_non_broadcast
+func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4596c8f9d5362..2774a82d6fb8b5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -338,7 +338,8 @@ func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tenso
// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 73eabab657f380..028105855ce25b 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -114,23 +114,24 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>)
// CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%2 = tosa.minimum %arg0, %arg1 ...
[truncated]
|
@@ -812,7 +812,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [ | |||
let arguments = (ins | |||
Tosa_Tensor:$input1, | |||
Tosa_Tensor:$input2, | |||
Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift | |||
// Apply right shift on i32_t input data only | |||
Tosa_ScalarInt8Tensor:$shift |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why shall this be required at dialect level? Not in favor of a compulsory parameter when it doesn't make sense. Any good reasoning behind this?
@@ -139,6 +148,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> : | |||
// Rank-0 (scalar) tensor | |||
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; | |||
|
|||
// Scalar tensors: Rank-1 (with only one element) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Scalar tensors have rank0. This contradicts the above definition. Maybe worth of a different name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
renamed to Tosa_Rank0Tensor
561ded2
to
cc5c33d
Compare
Hi @Tai78641 , can you do the rebase. |
Set the shift to be a mandatory operand. Change-Id: Ic8f6e2f653c6f5875f8b789a2da27d2834aa442a Signed-off-by: Tai Ly <[email protected]>
cc5c33d
to
73bf18e
Compare
rebased |
Approved, but let @GeorgeARM to take another look to double-confirm. |
Change the shift operand for the mul operator to be a required operand. Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor whose shape is [1] (ie, tensor containing a single element) Signed-off-by: Tai Ly <[email protected]>
Change the shift operand for the mul operator to be a required operand. Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor whose shape is [1] (ie, tensor containing a single element) Signed-off-by: Tai Ly <[email protected]>
Change the shift operand for the mul operator to be a required operand. Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor whose shape is [1] (ie, tensor containing a single element) Signed-off-by: Tai Ly <[email protected]>
Change the shift operand for the mul operator to be a required operand. Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor whose shape is [1] (ie, tensor containing a single element) Signed-off-by: Tai Ly <[email protected]>
Change the shift operand for the mul operator to be a required operand. Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor whose shape is [1] (ie, tensor containing a single element) Signed-off-by: Tai Ly <[email protected]>
Change the shift operand for the mul operator to be a required operand.
Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor
whose shape is [1] (ie, tensor containing a single element)