Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tosa] Change the shift of mul to be required #125297

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
Expand Down Expand Up @@ -136,8 +136,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
Expand Down Expand Up @@ -168,8 +168,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
Expand Down Expand Up @@ -356,8 +356,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
Expand Down Expand Up @@ -819,7 +819,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift
// Apply right shift on i32_t input data only
Tosa_ScalarInt8Tensor:$shift
Copy link
Contributor

@GeorgeARM GeorgeARM Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why shall this be required at dialect level? Not in favor of a compulsory parameter when it doesn't make sense. Any good reasoning behind this?

);

let results = (outs
Expand Down Expand Up @@ -1592,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
let arguments = (ins
Tosa_RankedTensor:$input1,
Tosa_Shape:$padding,
Optional<Tosa_ScalarTensor>:$pad_const,
Optional<Tosa_Rank0Tensor>:$pad_const,
OptionalAttr<I32Attr>:$input_zp
);

Expand Down
20 changes: 13 additions & 7 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;

def AllDimensionsAreSizeOne : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;

class TosaTensorOf<
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
Expand All @@ -109,6 +113,11 @@ class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;

class TosaScalarTensorOf<list<Type> allowedTypes, list<int> ranks>
: TosaRankedTensorOf<allowedTypes,
[HasAnyRankOfPred<ranks>, AllDimensionsAreSizeOne],
"tosa-conformant scalar tensor">;

//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -136,8 +145,10 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//

// Rank-0 (scalar) tensor
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;

def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;

// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
Expand Down Expand Up @@ -288,9 +299,4 @@ def Rank1TosaShape : TosaShapeOfRank<1>;
def Rank2TosaShape : TosaShapeOfRank<2>;
def Rank4TosaShape : TosaShapeOfRank<4>;

// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this
// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the
// following def can be removed.
def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>;

#endif // TOSA_TYPES_BASE
21 changes: 13 additions & 8 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,27 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MulOp
if (isa<tosa::MulOp>(op)) {
auto shift_val = cast<tosa::MulOp>(op).getShift();
ElementsAttr shift_elem;
if (!shift_val.getImpl() ||
!matchPattern(shift_val, m_Constant(&shift_elem))) {
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
}

int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();

if (isa<FloatType>(elementTy)) {
if (shift != 0) {
(void)rewriter.notifyMatchFailure(op,
"Cannot have shift value for float");
return nullptr;
}
return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
}

if (isa<IntegerType>(elementTy)) {
int32_t shift = 0;
ElementsAttr shift_elem;
if (shift_val.getImpl() &&
matchPattern(shift_val, m_Constant(&shift_elem))) {
// Explicit shift is set.
shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
}

Value a = args[0];
Value b = args[1];

if (shift > 0) {
auto shiftConst =
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
Expand Down
21 changes: 12 additions & 9 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1104,16 +1104,10 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
LogicalResult status = success();
// mul op's output shape only depend on input1 and input2, not on shift
ValueShapeRange twoInputs = operands.drop_back();
llvm::SmallVector<int64_t> outShape;
if (operands.size() == 2) {
status = resolveBroadcastShape(operands, outShape);
} else {
// mul op's output shape only depend on input1 and input2, not on shift
ValueShapeRange two_inputs = operands.drop_back();
status = resolveBroadcastShape(two_inputs, outShape);
}
if (status.failed()) {
if (resolveBroadcastShape(twoInputs, outShape).failed()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
} else {
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
Expand Down Expand Up @@ -1148,6 +1142,15 @@ LogicalResult tosa::MulOp::verify() {
return emitOpError(
"requires the same element type for all operands and results");
}

// verify shift has value 0 for non-integer types
ElementsAttr shift_elem;
if (matchPattern(getShift(), m_Constant(&shift_elem))) {
int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
if (shift != 0) {
return emitOpError() << "require shift to be 0 for float type";
}
}
}

// Verify the op has same ranks for all main operands (excludes extra operands
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,7 @@ bool TosaReduceTransposes::collectFanIn(Operation *op,

for (Value operand : op->getOperands()) {
// If this is a problem in future, think about alternatives to recursion.
if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
operand == op->getOperand(2)) {
if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
// do not recurse into MulOp's shift operand
continue;
}
Expand Down Expand Up @@ -332,8 +331,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
for (Value v : op->getOperands()) {
if (valuesMap.contains(v)) {
operands.push_back(valuesMap.at(v));
} else if (llvm::isa<tosa::MulOp>(op) && op->getNumOperands() == 3 &&
v == op->getOperand(2)) {
} else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
// special case for MulOp's shift operand
operands.push_back(v);
} else {
Expand Down
6 changes: 4 additions & 2 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {

// CHECK: linalg.generic
// CHECK: arith.mulf
%4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32>

// CHECK: linalg.generic
// CHECK: arith.negf
Expand Down Expand Up @@ -618,7 +619,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
// CHECK: arith.extsi
// CHECK: arith.extsi
// CHECK: arith.muli
%0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1xi32>

return
}
Expand Down
14 changes: 9 additions & 5 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,9 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}

Expand All @@ -343,7 +344,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.mul
%ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
}

Expand Down Expand Up @@ -379,11 +381,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
// CHECK-NOT: tosa.mul
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
%1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32>

// CHECK-NOT: tosa.mul
// CHECK: return %[[ZERO]], %[[ZERO]]
%2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
}

Expand Down Expand Up @@ -983,7 +986,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
// CHECK: tosa.mul
%0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
%1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
%2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>)-> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi8>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
}

Expand Down
15 changes: 10 additions & 5 deletions mlir/test/Dialect/Tosa/constant-op-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
%mul = tosa.mul %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%mul = tosa.mul %arg0, %zero, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
Expand All @@ -249,7 +250,8 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
%mul = tosa.mul %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%mul = tosa.mul %zero, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %[[ZERO]]
return %mul : tensor<f32>
}
Expand Down Expand Up @@ -283,7 +285,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-LABEL: @fold_mul_one_rhs_f32
func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
%mul = tosa.mul %arg0, %one : (tensor<f32>, tensor<f32>) -> tensor<f32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%mul = tosa.mul %arg0, %one, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %arg0
return %mul : tensor<f32>
}
Expand All @@ -293,7 +296,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @fold_mul_one_lhs_f32
func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
%mul = tosa.mul %one, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%mul = tosa.mul %one, %arg0, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<f32>
// CHECK: return %arg0
return %mul : tensor<f32>
}
Expand Down Expand Up @@ -339,7 +343,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> {
func.func @fold_mul_splat_f32() -> tensor<10xf32> {
%one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
%two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
%mul = tosa.mul %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32>
// CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>}
// CHECK: return %[[THREE]]
return %mul : tensor<10xf32>
Expand Down
42 changes: 35 additions & 7 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -750,26 +750,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,

// CHECK-LABEL: test_mul_type_mismatch
func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf16>) -> tensor<13x21x3xf32>
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

// -----

// CHECK-LABEL: test_mul_invalid_shift
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
%shift = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<f32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

// -----

// CHECK-LABEL: test_mul_missing_shift
func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
// this is ok because mul's shift operand is optional for now
// expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}}
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
Expand Down Expand Up @@ -1081,3 +1082,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
%0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32>
return %0 : tensor<1x13x21x3xf32>
}

// -----
// CHECK-LABEL: test_mul_non_scalar_shift_2d
func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8>
// expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: test_mul_non_scalar_shift_1d
func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{value = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8>
// expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: test_mul_non_broadcast
func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
3 changes: 2 additions & 1 deletion mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,8 @@ func.func @test_mul_scalar_with_unranked_output(%arg0: tensor<f32>, %arg1: tenso
// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

Expand Down
Loading