-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tosa] Fix merge problems with mul shift #125129
Conversation
This patch fixes merge issues in TosaOpBase.td and TosaOps.td wrt traits on tosa elementwise ops and multiply op which, with the optional shift operand, is no longer strictly an elementwise op. fixed up inferReturnTypeComponents to be based on only the first two operands (ie, ignoring shift, if present) also fixed up TosaReduceTransposes to special handle tosa mul op now that it is not an elementwise op. Signed-off-by: Tai Ly <[email protected]> Change-Id: I30bd10137870cfe079c761da21389ca6a3b2e5e8
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesThis patch fixes merge issues in TosaOpBase.td and TosaOps.td wrt traits on tosa elementwise ops and multiply op which, with the optional shift operand, is no longer strictly an elementwise op. fixed up inferReturnTypeComponents to be based on only the first two operands (ie, ignoring shift, if present) also fixed up TosaReduceTransposes to special handle tosa mul op now that it is not an elementwise op. Full diff: https://github.com/llvm/llvm-project/pull/125129.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 29afd6c27302cc..4975530a9588ca 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -239,7 +239,9 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
Tosa_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
+ ResultsBroadcastableShape,
TosaElementwiseOperator,
+ SameOperandsAndResultRank,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
@@ -248,8 +250,6 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
SameOperandsAndResultShape,
- ResultsBroadcastableShape,
- SameOperandsAndResultRank,
SameOperandsAndResultElementType])> {}
class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9e3e41d288e4ac..c59c582a1f5221 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -482,9 +482,7 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
//===----------------------------------------------------------------------===//
def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
Commutative,
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsAndResultElementType]> {
let summary = "Elementwise addition operator";
let description = [{
@@ -517,10 +515,8 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
//===----------------------------------------------------------------------===//
// Operator: arithmetic_right_shift
//===----------------------------------------------------------------------===//
-def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
+ [SameOperandsAndResultElementType]> {
let summary = "Elementwise Arithmetic Right Shift";
let description = [{
@@ -544,9 +540,7 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
Commutative,
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsAndResultElementType]> {
let summary = "Bitwise AND operator";
let description = [{
@@ -569,9 +563,7 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
Commutative,
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsAndResultElementType]> {
let summary = "Bitwise OR operator";
let description = [{
@@ -594,9 +586,7 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
Commutative,
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsAndResultElementType]> {
let summary = "Bitwise XOR operator";
let description = [{
@@ -617,10 +607,7 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
//===----------------------------------------------------------------------===//
// Operator: int_div
//===----------------------------------------------------------------------===//
-def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
- ResultsBroadcastableShape,
- SameOperandsAndResultRank,
- SameOperandsAndResultElementType]> {
+def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementType]> {
let summary = "Integer divide operator";
let description = [{
@@ -645,9 +632,7 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
Commutative,
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x AND y element-wise.";
let description = [{
@@ -668,10 +653,8 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
//===----------------------------------------------------------------------===//
// Operator: logical_left_shift
//===----------------------------------------------------------------------===//
-def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
+ [SameOperandsAndResultElementType]> {
let summary = "Elementwise Logical Left Shift";
let description = [{
@@ -692,10 +675,8 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
//===----------------------------------------------------------------------===//
// Operator: logical_right_shift
//===----------------------------------------------------------------------===//
-def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
+ [SameOperandsAndResultElementType]> {
let summary = "Elementwise Logical Right Shift";
let description = [{
@@ -718,9 +699,7 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
Commutative,
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x OR y element-wise.";
let description = [{
@@ -743,9 +722,7 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
Commutative,
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x XOR y element-wise.";
let description = [{
@@ -768,9 +745,7 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
//===----------------------------------------------------------------------===//
def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
Commutative,
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsAndResultElementType]> {
let summary = "Elementwise Maximum";
let description = [{
@@ -794,9 +769,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
//===----------------------------------------------------------------------===//
def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
Commutative,
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsAndResultElementType]> {
let summary = "Elementwise Minimum";
let description = [{
@@ -823,9 +796,11 @@ def MulOperandsAndResultElementType :
//===----------------------------------------------------------------------===//
// Operator: mul
//===----------------------------------------------------------------------===//
-def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
+def Tosa_MulOp : Tosa_Op<"mul", [
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["inferReturnTypeComponents"]>,
Commutative,
- MulOperandsAndResultElementType]> {
+ Pure]> {
let summary = "Multiplication operator";
let description = [{
@@ -846,15 +821,15 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
let hasFolder = 1;
let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
// Operator: pow
//===----------------------------------------------------------------------===//
-def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
let summary = "Computes the power of one value to another.";
let description = [{
@@ -875,10 +850,7 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
//===----------------------------------------------------------------------===//
// Operator: sub
//===----------------------------------------------------------------------===//
-def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [
- ResultsBroadcastableShape,
- SameOperandsAndResultElementType,
- SameOperandsAndResultRank]> {
+def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
let summary = "Elementwise subtraction operator";
let description = [{
@@ -1229,9 +1201,7 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
//===----------------------------------------------------------------------===//
// Operator: select
//===----------------------------------------------------------------------===//
-def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
- ResultsBroadcastableShape,
- SameOperandsAndResultRank]> {
+def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
let summary = "Elementwise select operator";
let description = [{
@@ -1267,9 +1237,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
InferTensorType,
Commutative,
- ResultsBroadcastableShape,
- SameOperandsElementType,
- SameOperandsAndResultRank]> {
+ SameOperandsElementType]> {
let summary = "Returns the truth value of (x == y) element-wise.";
let description = [{
@@ -1297,10 +1265,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
//===----------------------------------------------------------------------===//
// Operator: greater
//===----------------------------------------------------------------------===//
-def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
- ResultsBroadcastableShape,
- SameOperandsElementType,
- SameOperandsAndResultRank]> {
+def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
let summary = "Returns the truth value of (x > y) element-wise.";
let description = [{
@@ -1322,11 +1287,8 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
//===----------------------------------------------------------------------===//
// Operator: greater_equal
//===----------------------------------------------------------------------===//
-def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", [
- ResultsBroadcastableShape,
- SameOperandsElementType,
- SameOperandsAndResultRank
- ]> {
+def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
+ [SameOperandsElementType]> {
let summary = "Returns the truth value of (x >= y) element-wise.";
let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c0b419b6f473c8..0a10439db40803 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -958,6 +958,28 @@ LogicalResult tosa::SliceOp::verify() {
return success();
}
+LogicalResult tosa::MulOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ ValueShapeRange operands, DictionaryAttr attributes,
+ OpaqueProperties properties, RegionRange regions,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ LogicalResult status = success();
+ llvm::SmallVector<int64_t> outShape;
+ if (operands.size() == 2) {
+ status = resolveBroadcastShape(operands, outShape);
+ } else {
+ // mul op's output shape only depend on input1 and input2, not on shift
+ ValueShapeRange two_inputs = operands.drop_back();
+ status = resolveBroadcastShape(two_inputs, outShape);
+ }
+ if (status.failed()) {
+ inferredReturnShapes.push_back(ShapedTypeComponents());
+ } else {
+ inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+ }
+ return success();
+}
+
LogicalResult tosa::MulOp::verify() {
auto resElemType = getElementTypeOrSelf(getOutput());
@@ -1030,6 +1052,20 @@ LogicalResult tosa::MulOp::verify() {
}
}
+ // check for broadcast compatible shapes in first two operands (ignoring
+ // shift)
+
+ // delegate function that returns shape of shaped type
+ auto getShape = [](const Type type) {
+ return mlir::cast<ShapedType>(type).getShape();
+ };
+ SmallVector<int64_t> resultShape;
+ if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
+ getShape(rankedOperandTypes[1]),
+ resultShape)) {
+ return emitOpError("operands don't have broadcast-compatible shapes");
+ }
+
return success();
}
@@ -1670,7 +1706,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
-NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index e1f0a9592e8b4f..520f283a3ba888 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -281,13 +281,20 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
!llvm::isa<tosa::ConstOp>(op)) {
- if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ if (!llvm::isa<tosa::MulOp>(op) &&
+ !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
return false;
- for (Value operand : op->getOperands())
+ for (Value operand : op->getOperands()) {
// If this is a problem in future, think about alternatives to recursion.
+ if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
+ operand == op->getOperand(2)) {
+ // do not recurse into MulOp's shift operand
+ continue;
+ }
if (!collectFanIn(operand.getDefiningOp(), collected))
return false;
+ }
}
// Insert in topological order.
@@ -316,7 +323,8 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
Operation *op, const DenseMap<Value, Value> &valuesMap,
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
if (op->getNumResults() != 1 ||
- !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ (!llvm::isa<tosa::MulOp>(op) &&
+ !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
return std::nullopt;
auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
@@ -324,6 +332,10 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
for (Value v : op->getOperands()) {
if (valuesMap.contains(v)) {
operands.push_back(valuesMap.at(v));
+ } else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
+ v == op->getOperand(2)) {
+ // special case for MulOp's shift operand
+ operands.push_back(v);
} else {
return std::nullopt;
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 51d7f828510613..ac4d466aef94b2 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -183,7 +183,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> t
// -----
func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
- %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
%1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
return
@@ -211,7 +211,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
// -----
func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
%1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
return %1 : tensor<13x21x3xf32>
@@ -749,7 +749,7 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1
// CHECK-LABEL: test_mul_missing_shift
func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
- // expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
+ // this is ok because mul's shift operand is optional for now
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 9eba2f7e5a06e4..a4596c8f9d5362 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -327,6 +327,14 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: test_mul_scalar_with_unranked_output
+func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf32> {
+ %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
|
This patch fixes merge issues in TosaOpBase.td and TosaOps.td wrt traits on tosa elementwise ops and multiply op which, with the optional shift operand, is no longer strictly an elementwise op.
fixed up inferReturnTypeComponents to be based on only the first two operands (ie, ignoring shift, if present)
also fixed up TosaReduceTransposes to special handle tosa mul op now that it is not an elementwise op.