Skip to content

Commit

Permalink
[TOSA] Switch zero point of avgpool2d to input variable type
Browse files Browse the repository at this point in the history
This commit changes the zero point attribute to an input to
align with the 1.0 spec.

Change-Id: Ieee6ba824327913bc8462cbcb7a74c0b6dd53d21
Signed-off-by: Luke Hutton <[email protected]>
  • Loading branch information
lhutton1 authored and Tai78641 committed Feb 27, 2025
1 parent f409340 commit a026655
Show file tree
Hide file tree
Showing 15 changed files with 232 additions and 118 deletions.
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ profileComplianceMap = {
{{{Profile::pro_int}, {{i8T, i32T}}},
{{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
{"tosa.avg_pool2d",
{{{Profile::pro_int}, {{i8T, i32T, i8T}}},
{{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
{{Profile::pro_fp},
{{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{{fp16T, fp16T, fp16T, fp16T, fp16T}, {fp16T, fp16T, fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.conv2d",
{{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
{{Profile::pro_fp},
Expand Down Expand Up @@ -243,10 +243,10 @@ extensionComplianceMap = {
{{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
{{Extension::bf16}, {{bf16T, i32T}}}}},
{"tosa.avg_pool2d",
{{{Extension::int16}, {{i16T, i32T, i16T}}},
{{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
{{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
{"tosa.conv2d",
{{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
{{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
Expand Down
14 changes: 11 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$output_zp,
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<I32Attr>:$input_zp,
OptionalAttr<I32Attr>:$output_zp
TypeAttrOf<Tosa_AccType>:$acc_type
);

let results = (outs
Expand All @@ -97,6 +97,14 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
];

let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getOutputZeroPoint(int64_t &zp);
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];

let hasVerifier = 1;
}

Expand Down
21 changes: 15 additions & 6 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,15 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
return failure();
SmallVector<Value> dynamicDims = *dynamicDimsOr;

int64_t inputZpVal;
int64_t outputZpVal;
if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getOutputZeroPoint(outputZpVal).failed()) {
(void)rewriter.notifyMatchFailure(
op, "zero points could not be statically determined");
return failure();
}

// Apply padding as necessary.
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
Expand Down Expand Up @@ -923,9 +932,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {

// If we have quantization information we need to apply an offset
// for the input zp value.
if (op.getInputZp()) {
auto inputZp =
rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
if (inputZpVal != 0) {
auto inputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(accETy, inputZpVal));
Value offset =
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
poolVal =
Expand Down Expand Up @@ -977,9 +986,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {

// If we have quantization information we need to apply output
// zeropoint.
if (op.getOutputZp()) {
auto outputZp =
rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
if (outputZpVal != 0) {
auto outputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
.getResult();
}
Expand Down
111 changes: 64 additions & 47 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,18 +455,10 @@ LogicalResult tosa::ArgMaxOp::verify() {
}

LogicalResult tosa::AvgPool2dOp::verify() {
auto inputType = llvm::cast<ShapedType>(getInput().getType());

auto inputETy = inputType.getElementType();
auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();

if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
inputETy = quantType.getStorageType();

if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
resultETy = quantType.getStorageType();
const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());

auto accType = getAccType();
if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
Expand All @@ -481,6 +473,28 @@ LogicalResult tosa::AvgPool2dOp::verify() {
if (inputETy.isF32() && !accType.isF32())
return emitOpError("accumulator type for f32 tensor is not f32");

if (inputETy != inputZpETy)
return emitOpError("expect both input and its zero point are the same "
"element type, got ")
<< inputETy << " and " << inputZpETy;

if (resultETy != outputZpETy)
return emitOpError("expect both output and its zero point are the same "
"element type, got ")
<< resultETy << " and " << outputZpETy;

int64_t inputZpVal;
if (getInputZeroPoint(inputZpVal).succeeded() &&
verifyInputZeroPoint(inputZpVal).failed())
return emitOpError(
"input zero point must be zero for non-int8 integer types");

int64_t outputZpVal;
if (getOutputZeroPoint(outputZpVal).succeeded() &&
verifyOutputZeroPoint(outputZpVal).failed())
return emitOpError(
"output zero point must be zero for non-int8 integer types");

if ((inputETy.isF32() && resultETy.isF32()) ||
(inputETy.isF16() && resultETy.isF16()) ||
(inputETy.isBF16() && resultETy.isBF16()) ||
Expand Down Expand Up @@ -629,27 +643,37 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
}

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it
/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
/// has additional parameters not part of the unary ops.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input,
DenseArrayAttr kernel, DenseArrayAttr stride,
DenseArrayAttr pad, TypeAttr accType) {
result.addOperands(input);
const Location loc{result.location};
int64_t inputZp{0};
int64_t outputZp{0};

auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
if (quantAttr) {
inputZp = quantAttr.getInputZp();
outputZp = quantAttr.getOutputZp();
}
const std::optional<Value> inputZpOp =
createZeroPointTensor(builder, loc, input.getType(), inputZp);
assert(
inputZpOp.has_value() &&
"Failed to create input zero point tensor for quantized AVG_POOL2D op");
const std::optional<Value> outputZpOp =
createZeroPointTensor(builder, loc, outputType, outputZp);
assert(
outputZpOp.has_value() &&
"Failed to create output zero point tensor for quantized AVG_POOL2D op");
result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
result.addAttribute("kernel", kernel);
result.addAttribute("stride", stride);
result.addAttribute("pad", pad);
result.addAttribute("acc_type", accType);
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
if (quantAttr) {
result.addAttribute("input_zp",
builder.getI32IntegerAttr(
static_cast<int32_t>(quantAttr.getInputZp())));
result.addAttribute("output_zp",
builder.getI32IntegerAttr(
static_cast<int32_t>(quantAttr.getOutputZp())));
}
result.types.push_back(outputType);
}

Expand Down Expand Up @@ -1425,13 +1449,6 @@ static LogicalResult getZeroPoint(T op, Value val, int64_t &zp) {

template <typename T>
static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
// TODO clean it up when the entire zero point (attribute -> input tensor
// type) change is done. Remaining Matmul, Rescale, Negate, and AvgPool2D.
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
!std::is_same_v<T, DepthwiseConv2DOp> &&
!std::is_same_v<T, TransposeConv2DOp>)
return failure();

Type zpElemType = getElementTypeOrSelf(val);

if (!zpElemType.isIntOrFloat())
Expand All @@ -1446,24 +1463,24 @@ static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
return success();
}

#define ZERO_POINT_HELPER(OP) \
LogicalResult tosa::OP::getInputZeroPoint(int64_t &zp) { \
return getZeroPoint(*this, getInputZp(), zp); \
#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
LogicalResult tosa::OP::get##OPERAND_NAME##ZeroPoint(int64_t &zp) { \
return getZeroPoint(*this, get##OPERAND_NAME##Zp(), zp); \
} \
LogicalResult tosa::OP::getWeightZeroPoint(int64_t &zp) { \
return getZeroPoint(*this, getWeightZp(), zp); \
} \
LogicalResult tosa::OP::verifyInputZeroPoint(int64_t zp) { \
return verifyZeroPoint(*this, getInputZp(), zp); \
} \
LogicalResult tosa::OP::verifyWeightZeroPoint(int64_t zp) { \
return verifyZeroPoint(*this, getWeightZp(), zp); \
}

ZERO_POINT_HELPER(Conv2DOp)
ZERO_POINT_HELPER(Conv3DOp)
ZERO_POINT_HELPER(DepthwiseConv2DOp)
ZERO_POINT_HELPER(TransposeConv2DOp)
LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp); \
}

ZERO_POINT_HELPER(Conv2DOp, Input)
ZERO_POINT_HELPER(Conv2DOp, Weight)
ZERO_POINT_HELPER(Conv3DOp, Input)
ZERO_POINT_HELPER(Conv3DOp, Weight)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
ZERO_POINT_HELPER(TransposeConv2DOp, Input)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
ZERO_POINT_HELPER(AvgPool2dOp, Input)
ZERO_POINT_HELPER(AvgPool2dOp, Output)
#undef ZERO_POINT_HELPER

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
// Populates the profile info for tosa.avg_pool2d. Entries are recorded in
// this order: input, input zero point, output zero point, accumulator type,
// output.
// NOTE(review): this order presumably has to line up with the five-entry
// type lists registered for "tosa.avg_pool2d" in TosaComplianceData.h.inc --
// confirm before reordering any of the calls below.
template <>
void ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
addValue(op.getInput());
// The zero points are operands of the op (not attributes), so they are
// recorded as values just like the input.
addValue(op.getInputZp());
addValue(op.getOutputZp());
addType(op.getAccType());
addValue(op.getOutput());
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics

// CHECK-LABEL: @avg_pool2d_with_unsupported_quant_type
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
// expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
}

Expand Down
16 changes: 12 additions & 4 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,9 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
// CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]]
// CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
// CHECK: linalg.yield %[[DIV]]
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> tensor<1x5x33x62xf32>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x33x62xf32>
return %0 : tensor<1x5x33x62xf32>
}

Expand Down Expand Up @@ -375,7 +377,9 @@ func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x6
// CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[DIV]]
// CHECK: linalg.yield %[[TRUNC]]
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>) -> tensor<1x5x33x62xf16>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x33x62xf16>
return %0 : tensor<1x5x33x62xf16>
}

Expand Down Expand Up @@ -416,7 +420,9 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
// CHECK: %[[CLAMP:.+]] = arith.minsi %[[CMAX]], %[[LOW]]
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
// CHECK: linalg.yield %[[TRUNC]]
%0 = tosa.avg_pool2d %arg0 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> tensor<1x5x33x62xi8>
%input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x33x62xi8>
return %0 : tensor<1x5x33x62xi8>
}

Expand All @@ -439,7 +445,9 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> tensor<?x5x33x62xf32>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x33x62xf32>
return %0 : tensor<?x5x33x62xf32>
}

Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
// -----

// check that the tosa verifier kicks in
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
: (tensor<1x0x?x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}

// -----

// check that --tosa-to-linalg kicks in
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
// expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
}
4 changes: 3 additions & 1 deletion mlir/test/Dialect/Tosa/availability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}

Expand Down
Loading

0 comments on commit a026655

Please sign in to comment.