Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir[tosa] Switch zero point of avgpool2d to input variable type #128983

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Tai78641
Copy link
Contributor

This commit changes the TOSA operator AvgPool2d's zero point attributes to inputs to align with TOSA 1.0 spec.

Change-Id: Ieee6ba824327913bc8462cbcb7a74c0b6dd53d21

@llvmbot
Copy link
Member

llvmbot commented Feb 27, 2025

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Tai Ly (Tai78641)

Changes

This commit changes the TOSA operator AvgPool2d's zero point attributes to inputs to align with TOSA 1.0 spec.

Change-Id: Ieee6ba824327913bc8462cbcb7a74c0b6dd53d21


Patch is 44.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128983.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+6-6)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+11-3)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+15-6)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+64-47)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+2-2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+12-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+5-5)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+3-1)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+50-6)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+24-24)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+18-6)
  • (modified) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+16-4)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 2617a902c3a0d..941a7fe2b23ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -5,9 +5,9 @@ profileComplianceMap = {
      {{{Profile::pro_int}, {{i8T, i32T}}},
       {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
     {"tosa.avg_pool2d",
-     {{{Profile::pro_int}, {{i8T, i32T, i8T}}},
+     {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
       {{Profile::pro_fp},
-       {{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T}, {fp16T, fp16T, fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.conv2d",
      {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
       {{Profile::pro_fp},
@@ -243,10 +243,10 @@ extensionComplianceMap = {
       {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
       {{Extension::bf16}, {{bf16T, i32T}}}}},
     {"tosa.avg_pool2d",
-     {{{Extension::int16}, {{i16T, i32T, i16T}}},
-      {{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
-      {{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
-      {{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
+     {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
     {"tosa.conv2d",
      {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
       {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..27a4bc57d1964 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -79,12 +79,12 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$output_zp,
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
-    TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<I32Attr>:$input_zp,
-    OptionalAttr<I32Attr>:$output_zp
+    TypeAttrOf<Tosa_AccType>:$acc_type
   );
 
   let results = (outs
@@ -97,6 +97,14 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
   ];
 
   let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
+
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getOutputZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyOutputZeroPoint(int64_t zp);
+  }];
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 006e35806d64f..9c8d4f75c6e37 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -804,6 +804,15 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
       return failure();
     SmallVector<Value> dynamicDims = *dynamicDimsOr;
 
+    int64_t inputZpVal;
+    int64_t outputZpVal;
+    if (op.getInputZeroPoint(inputZpVal).failed() ||
+        op.getOutputZeroPoint(outputZpVal).failed()) {
+      (void)rewriter.notifyMatchFailure(
+          op, "zero points could not be statically determined");
+      return failure();
+    }
+
     // Apply padding as necessary.
     llvm::SmallVector<int64_t> pad;
     pad.resize(2, 0);
@@ -923,9 +932,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply an offset
             // for the input zp value.
-            if (op.getInputZp()) {
-              auto inputZp =
-                  rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+            if (inputZpVal != 0) {
+              auto inputZp = rewriter.create<arith::ConstantOp>(
+                  loc, b.getIntegerAttr(accETy, inputZpVal));
               Value offset =
                   rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
               poolVal =
@@ -977,9 +986,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply output
             // zeropoint.
-            if (op.getOutputZp()) {
-              auto outputZp =
-                  rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
+            if (outputZpVal != 0) {
+              auto outputZp = rewriter.create<arith::ConstantOp>(
+                  loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
               scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                            .getResult();
             }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..63e183538c257 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -455,18 +455,10 @@ LogicalResult tosa::ArgMaxOp::verify() {
 }
 
 LogicalResult tosa::AvgPool2dOp::verify() {
-  auto inputType = llvm::cast<ShapedType>(getInput().getType());
-
-  auto inputETy = inputType.getElementType();
-  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
-
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
-    inputETy = quantType.getStorageType();
-
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
-    resultETy = quantType.getStorageType();
+  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
+  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
+  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
+  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
 
   auto accType = getAccType();
   if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
@@ -481,6 +473,28 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (inputETy.isF32() && !accType.isF32())
     return emitOpError("accumulator type for f32 tensor is not f32");
 
+  if (inputETy != inputZpETy)
+    return emitOpError("expect both input and its zero point are the same "
+                       "element type, got ")
+           << inputETy << " and " << inputZpETy;
+
+  if (resultETy != outputZpETy)
+    return emitOpError("expect both output and its zero point are the same "
+                       "element type, got ")
+           << resultETy << " and " << outputZpETy;
+
+  int64_t inputZpVal;
+  if (getInputZeroPoint(inputZpVal).succeeded() &&
+      verifyInputZeroPoint(inputZpVal).failed())
+    return emitOpError(
+        "input zero point must be zero for non-int8 integer types");
+
+  int64_t outputZpVal;
+  if (getOutputZeroPoint(outputZpVal).succeeded() &&
+      verifyOutputZeroPoint(outputZpVal).failed())
+    return emitOpError(
+        "output zero point must be zero for non-int8 integer types");
+
   if ((inputETy.isF32() && resultETy.isF32()) ||
       (inputETy.isF16() && resultETy.isF16()) ||
       (inputETy.isBF16() && resultETy.isBF16()) ||
@@ -629,27 +643,37 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
 }
 
 /// Both the tosa.avg_pool2d and unary ops use the same
-/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it
+/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
 /// has additional parameters not part of the unary ops.
 static void
 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                               Type outputType, Value input,
                               DenseArrayAttr kernel, DenseArrayAttr stride,
                               DenseArrayAttr pad, TypeAttr accType) {
-  result.addOperands(input);
+  const Location loc{result.location};
+  int64_t inputZp{0};
+  int64_t outputZp{0};
+
+  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
+  if (quantAttr) {
+    inputZp = quantAttr.getInputZp();
+    outputZp = quantAttr.getOutputZp();
+  }
+  const std::optional<Value> inputZpOp =
+      createZeroPointTensor(builder, loc, input.getType(), inputZp);
+  assert(
+      inputZpOp.has_value() &&
+      "Failed to create input zero point tensor for quantized AVG_POOL2D op");
+  const std::optional<Value> outputZpOp =
+      createZeroPointTensor(builder, loc, outputType, outputZp);
+  assert(
+      outputZpOp.has_value() &&
+      "Failed to create output zero point tensor for quantized AVG_POOL2D op");
+  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
   result.addAttribute("kernel", kernel);
   result.addAttribute("stride", stride);
   result.addAttribute("pad", pad);
   result.addAttribute("acc_type", accType);
-  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr) {
-    result.addAttribute("input_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getInputZp())));
-    result.addAttribute("output_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getOutputZp())));
-  }
   result.types.push_back(outputType);
 }
 
@@ -1425,13 +1449,6 @@ static LogicalResult getZeroPoint(T op, Value val, int64_t &zp) {
 
 template <typename T>
 static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
-  // TODO clean it up when the entire zero point (attribute -> input tensor
-  // type) change is done. Remaining Matmul, Rescale, Negate, and AvgPool2D.
-  if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
-                !std::is_same_v<T, DepthwiseConv2DOp> &&
-                !std::is_same_v<T, TransposeConv2DOp>)
-    return failure();
-
   Type zpElemType = getElementTypeOrSelf(val);
 
   if (!zpElemType.isIntOrFloat())
@@ -1446,24 +1463,24 @@ static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
   return success();
 }
 
-#define ZERO_POINT_HELPER(OP)                                                  \
-  LogicalResult tosa::OP::getInputZeroPoint(int64_t &zp) {                     \
-    return getZeroPoint(*this, getInputZp(), zp);                              \
+#define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
+  LogicalResult tosa::OP::get##OPERAND_NAME##ZeroPoint(int64_t &zp) {          \
+    return getZeroPoint(*this, get##OPERAND_NAME##Zp(), zp);                   \
   }                                                                            \
-  LogicalResult tosa::OP::getWeightZeroPoint(int64_t &zp) {                    \
-    return getZeroPoint(*this, getWeightZp(), zp);                             \
-  }                                                                            \
-  LogicalResult tosa::OP::verifyInputZeroPoint(int64_t zp) {                   \
-    return verifyZeroPoint(*this, getInputZp(), zp);                           \
-  }                                                                            \
-  LogicalResult tosa::OP::verifyWeightZeroPoint(int64_t zp) {                  \
-    return verifyZeroPoint(*this, getWeightZp(), zp);                          \
-  }
-
-ZERO_POINT_HELPER(Conv2DOp)
-ZERO_POINT_HELPER(Conv3DOp)
-ZERO_POINT_HELPER(DepthwiseConv2DOp)
-ZERO_POINT_HELPER(TransposeConv2DOp)
+  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
+    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp);                \
+  }
+
+ZERO_POINT_HELPER(Conv2DOp, Input)
+ZERO_POINT_HELPER(Conv2DOp, Weight)
+ZERO_POINT_HELPER(Conv3DOp, Input)
+ZERO_POINT_HELPER(Conv3DOp, Weight)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
+ZERO_POINT_HELPER(TransposeConv2DOp, Input)
+ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
+ZERO_POINT_HELPER(AvgPool2dOp, Input)
+ZERO_POINT_HELPER(AvgPool2dOp, Output)
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 1d8aaa65c2976..345616c9563b5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -58,6 +58,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
 template <>
 void ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
   addValue(op.getInput());
+  addValue(op.getInputZp());
+  addValue(op.getOutputZp());
   addType(op.getAccType());
   addValue(op.getOutput());
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 5db3f56cf459e..3bd1049f1a0ce 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics
 
 // CHECK-LABEL: @avg_pool2d_with_unsupported_quant_type
-func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
+func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
   // expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
-  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
+  %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
   return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 332b706871547..386855420f48e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -290,7 +290,9 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
   // CHECK:   %[[FLT:.+]] = arith.sitofp %[[CAST]]
   // CHECK:   %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
   // CHECK:   linalg.yield %[[DIV]]
-  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> tensor<1x5x33x62xf32>
+  %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x33x62xf32>
   return %0 : tensor<1x5x33x62xf32>
 }
 
@@ -375,7 +377,9 @@ func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x6
   // CHECK:   %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
   // CHECK:   %[[TRUNC:.+]] = arith.truncf %[[DIV]]
   // CHECK:   linalg.yield %[[TRUNC]]
-  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>) -> tensor<1x5x33x62xf16>
+  %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x33x62xf16>
   return %0 : tensor<1x5x33x62xf16>
 }
 
@@ -416,7 +420,9 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
   // CHECK: %[[CLAMP:.+]] = arith.minsi %[[CMAX]], %[[LOW]]
   // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
   // CHECK: linalg.yield %[[TRUNC]]
-  %0 = tosa.avg_pool2d %arg0 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> tensor<1x5x33x62xi8>
+  %input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x33x62xi8>
   return %0 : tensor<1x5x33x62xi8>
 }
 
@@ -439,7 +445,9 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
   // CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
   // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
   // CHECK: %[[GENERIC:.+]] = linalg.generic
-  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> tensor<?x5x33x62xf32>
+  %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x33x62xf32>
   return %0 : tensor<?x5x33x62xf32>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 73da2810abe04..ecd5c792e08b6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -23,18 +23,18 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
 // -----
 
 // check that tosa verify kick in
-func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
+func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
   // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
-    %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
-      : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
+    %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
+      : (tensor<1x0x?x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
 }
 
 // -----
 
 // check that --tosa-to-linalg kick in
-func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
+func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
   // expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
-  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor...
[truncated]

Copy link

github-actions bot commented Feb 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

This commit changes the zero point attribute to an input to
align with the 1.0 spec.

Change-Id: Ieee6ba824327913bc8462cbcb7a74c0b6dd53d21
Signed-off-by: Luke Hutton <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants