Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tosa] Fix merge problems with mul shift #125129

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
@@ -239,7 +239,9 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
Tosa_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape,
TosaElementwiseOperator,
SameOperandsAndResultRank,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
@@ -248,8 +250,6 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
SameOperandsAndResultShape,
ResultsBroadcastableShape,
SameOperandsAndResultRank,
SameOperandsAndResultElementType])> {}

class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>
98 changes: 30 additions & 68 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
@@ -482,9 +482,7 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
//===----------------------------------------------------------------------===//
def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Elementwise addition operator";

let description = [{
@@ -517,10 +515,8 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
//===----------------------------------------------------------------------===//
// Operator: arithmetic_right_shift
//===----------------------------------------------------------------------===//
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
[SameOperandsAndResultElementType]> {
let summary = "Elementwise Arithmetic Right Shift";

let description = [{
@@ -544,9 +540,7 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Bitwise AND operator";

let description = [{
@@ -569,9 +563,7 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Bitwise OR operator";

let description = [{
@@ -594,9 +586,7 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
//===----------------------------------------------------------------------===//
def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Bitwise XOR operator";

let description = [{
@@ -617,10 +607,7 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
//===----------------------------------------------------------------------===//
// Operator: int_div
//===----------------------------------------------------------------------===//
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
ResultsBroadcastableShape,
SameOperandsAndResultRank,
SameOperandsAndResultElementType]> {
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementType]> {
let summary = "Integer divide operator";

let description = [{
@@ -645,9 +632,7 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x AND y element-wise.";

let description = [{
@@ -668,10 +653,8 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
//===----------------------------------------------------------------------===//
// Operator: logical_left_shift
//===----------------------------------------------------------------------===//
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
[SameOperandsAndResultElementType]> {
let summary = "Elementwise Logical Left Shift";

let description = [{
@@ -692,10 +675,8 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
//===----------------------------------------------------------------------===//
// Operator: logical_right_shift
//===----------------------------------------------------------------------===//
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
[SameOperandsAndResultElementType]> {
let summary = "Elementwise Logical Right Shift";

let description = [{
@@ -718,9 +699,7 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x OR y element-wise.";

let description = [{
@@ -743,9 +722,7 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
//===----------------------------------------------------------------------===//
def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Returns the truth value of x XOR y element-wise.";

let description = [{
@@ -768,9 +745,7 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
//===----------------------------------------------------------------------===//
def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Elementwise Maximum";

let description = [{
@@ -794,9 +769,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
//===----------------------------------------------------------------------===//
def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
Commutative,
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
SameOperandsAndResultElementType]> {
let summary = "Elementwise Minimum";

let description = [{
@@ -823,9 +796,11 @@ def MulOperandsAndResultElementType :
//===----------------------------------------------------------------------===//
// Operator: mul
//===----------------------------------------------------------------------===//
def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
def Tosa_MulOp : Tosa_Op<"mul", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Commutative,
MulOperandsAndResultElementType]> {
Pure]> {
let summary = "Multiplication operator";

let description = [{
@@ -846,15 +821,15 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [

let hasFolder = 1;
let hasVerifier = 1;

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
}

//===----------------------------------------------------------------------===//
// Operator: pow
//===----------------------------------------------------------------------===//
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
let summary = "Computes the power of one value to another.";

let description = [{
@@ -875,10 +850,7 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
//===----------------------------------------------------------------------===//
// Operator: sub
//===----------------------------------------------------------------------===//
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [
ResultsBroadcastableShape,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
let summary = "Elementwise subtraction operator";

let description = [{
@@ -1229,9 +1201,7 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
//===----------------------------------------------------------------------===//
// Operator: select
//===----------------------------------------------------------------------===//
def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
ResultsBroadcastableShape,
SameOperandsAndResultRank]> {
def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
let summary = "Elementwise select operator";

let description = [{
@@ -1267,9 +1237,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
InferTensorType,
Commutative,
ResultsBroadcastableShape,
SameOperandsElementType,
SameOperandsAndResultRank]> {
SameOperandsElementType]> {
let summary = "Returns the truth value of (x == y) element-wise.";

let description = [{
@@ -1297,10 +1265,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
//===----------------------------------------------------------------------===//
// Operator: greater
//===----------------------------------------------------------------------===//
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
ResultsBroadcastableShape,
SameOperandsElementType,
SameOperandsAndResultRank]> {
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
let summary = "Returns the truth value of (x > y) element-wise.";

let description = [{
@@ -1322,11 +1287,8 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
//===----------------------------------------------------------------------===//
// Operator: greater_equal
//===----------------------------------------------------------------------===//
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", [
ResultsBroadcastableShape,
SameOperandsElementType,
SameOperandsAndResultRank
]> {
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
[SameOperandsElementType]> {
let summary = "Returns the truth value of (x >= y) element-wise.";

let description = [{
37 changes: 36 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
@@ -958,6 +958,28 @@ LogicalResult tosa::SliceOp::verify() {
return success();
}

LogicalResult tosa::MulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
LogicalResult status = success();
llvm::SmallVector<int64_t> outShape;
if (operands.size() == 2) {
status = resolveBroadcastShape(operands, outShape);
} else {
// mul op's output shape only depend on input1 and input2, not on shift
ValueShapeRange two_inputs = operands.drop_back();
status = resolveBroadcastShape(two_inputs, outShape);
}
if (status.failed()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
} else {
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
}
return success();
}

LogicalResult tosa::MulOp::verify() {
auto resElemType = getElementTypeOrSelf(getOutput());

@@ -1030,6 +1052,20 @@ LogicalResult tosa::MulOp::verify() {
}
}

// check for broadcast compatible shapes in first two operands (ignoring
// shift)

// delegate function that returns shape of shaped type
auto getShape = [](const Type type) {
return mlir::cast<ShapedType>(type).getShape();
};
SmallVector<int64_t> resultShape;
if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
getShape(rankedOperandTypes[1]),
resultShape)) {
return emitOpError("operands don't have broadcast-compatible shapes");
}

return success();
}

@@ -1670,7 +1706,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
18 changes: 15 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
Original file line number Diff line number Diff line change
@@ -281,13 +281,20 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,
if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
!llvm::isa<tosa::ConstOp>(op)) {

if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
if (!llvm::isa<tosa::MulOp>(op) &&
!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
return false;

for (Value operand : op->getOperands())
for (Value operand : op->getOperands()) {
// If this is a problem in future, think about alternatives to recursion.
if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
operand == op->getOperand(2)) {
// do not recurse into MulOp's shift operand
continue;
}
if (!collectFanIn(operand.getDefiningOp(), collected))
return false;
}
}

// Insert in topological order.
@@ -316,14 +323,19 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
Operation *op, const DenseMap<Value, Value> &valuesMap,
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
if (op->getNumResults() != 1 ||
!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
(!llvm::isa<tosa::MulOp>(op) &&
!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
return std::nullopt;

auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
SmallVector<Value, 3> operands;
for (Value v : op->getOperands()) {
if (valuesMap.contains(v)) {
operands.push_back(valuesMap.at(v));
} else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
v == op->getOperand(2)) {
// special case for MulOp's shift operand
operands.push_back(v);
} else {
return std::nullopt;
}
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
@@ -183,7 +183,7 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> t
// -----

func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
%padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
%padding = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
%1 = tosa.pad %arg0, %padding : (tensor<13x21xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
return
@@ -211,7 +211,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
// -----

func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
// expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
%1 = tosa.pad %arg0, %0 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3xf32>
return %1 : tensor<13x21x3xf32>
@@ -749,7 +749,7 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1

// CHECK-LABEL: test_mul_missing_shift
func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
// expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
// this is ok because mul's shift operand is optional for now
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
8 changes: 8 additions & 0 deletions mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
@@ -327,6 +327,14 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: test_mul_scalar_with_unranked_output
func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf32> {
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {