Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Vector] Move vector.insert canonicalizers for DenseElementsAttr to folders #128040

Merged
merged 1 commit into from
Mar 6, 2025

Conversation

Groverkss
Copy link
Member

This PR moves vector.insert canonicalizers for DenseElementsAttr (splat and non splat case) to folders. Folders are local, and it's always better to implement a folder than a canonicalizer.

This PR is mostly NFC-ish, because the functionality mostly remains same, but is now run as part of a folder, which is why some tests are changed, because GreedyPatternRewriter tries to fold by default.

@Groverkss
Copy link
Member Author

Depends on #127995 , please review top commit only

@llvmbot
Copy link
Member

llvmbot commented Feb 20, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Kunwar Grover (Groverkss)

Changes

This PR moves vector.insert canonicalizers for DenseElementsAttr (splat and non splat case) to folders. Folders are local, and it's always better to implement a folder than a canonicalizer.

This PR is mostly NFC-ish, because the functionality mostly remains same, but is now run as part of a folder, which is why some tests are changed, because GreedyPatternRewriter tries to fold by default.


Patch is 24.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128040.diff

7 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+100-141)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+3-7)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+7-19)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir (+2-3)
  • (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+5-9)
  • (modified) mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir (+6-11)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..f21ad23a03c6e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
   return {};
 }
 
+static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
+                                                   Attribute srcAttr) {
+  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
+  if (!denseAttr) {
+    return {};
+  }
+
+  if (denseAttr.isSplat()) {
+    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
+    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
+      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+    return newAttr;
+  }
+
+  auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType());
+  if (vecTy.isScalable())
+    return {};
+
+  if (extractOp.hasDynamicPosition()) {
+    return {};
+  }
+
+  // Calculate the linearized position of the continuous chunk of elements to
+  // extract.
+  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+  copy(extractOp.getStaticPosition(), completePositions.begin());
+  int64_t elemBeginPosition =
+      linearize(completePositions, computeStrides(vecTy.getShape()));
+  auto denseValuesBegin =
+      denseAttr.value_begin<TypedAttr>() + elemBeginPosition;
+
+  TypedAttr newAttr;
+  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
+    SmallVector<Attribute> elementValues(
+        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
+    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
+  } else {
+    newAttr = *denseValuesBegin;
+  }
+
+  return newAttr;
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2101,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return res;
   if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
     return res;
+  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+    return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2121,80 +2166,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
-// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractOp' operand is not defined by a splat vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
-    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
-      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
-class ExtractOpNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Canonicalization for dynamic position not implemented yet.
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
-    // Return if 'ExtractOp' operand is not defined by a compatible vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
-    if (vecTy.isScalable())
-      return failure();
-
-    // The splat case is handled by `ExtractOpSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // Calculate the linearized position of the continuous chunk of elements to
-    // extract.
-    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
-    copy(extractOp.getStaticPosition(), completePositions.begin());
-    int64_t elemBeginPosition =
-        linearize(completePositions, computeStrides(vecTy.getShape()));
-    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
-
-    TypedAttr newAttr;
-    if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
-      SmallVector<Attribute> elementValues(
-          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
-      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
-    } else {
-      newAttr = *denseValuesBegin;
-    }
-
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
 public:
@@ -2332,8 +2303,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -3043,94 +3013,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
-// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
-class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  // Do not create constants with more than `vectorSizeFoldThreashold` elements,
-  // unless the source vector constant has a single use.
-  static constexpr int64_t vectorSizeFoldThreshold = 256;
-
-  LogicalResult matchAndRewrite(InsertOp op,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Canonicalization for dynamic position not implemented yet.
-    if (op.hasDynamicPosition())
-      return failure();
+} // namespace
 
-    // Return if 'InsertOp' operand is not defined by a compatible vector
-    // ConstantOp.
-    TypedValue<VectorType> destVector = op.getDest();
-    Attribute vectorDestCst;
-    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
-      return failure();
-    auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
-    if (!denseDest)
-      return failure();
+static Attribute
+foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
+                                  Attribute dstAttr,
+                                  int64_t maxVectorSizeFoldThreshold) {
+  if (insertOp.hasDynamicPosition())
+    return {};
 
-    VectorType destTy = destVector.getType();
-    if (destTy.isScalable())
-      return failure();
+  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
+  if (!denseDst)
+    return {};
 
-    // Make sure we do not create too many large constants.
-    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
-        !destVector.hasOneUse())
-      return failure();
+  if (!srcAttr) {
+    return {};
+  }
 
-    Value sourceValue = op.getSource();
-    Attribute sourceCst;
-    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
-      return failure();
+  VectorType destTy = insertOp.getDestVectorType();
+  if (destTy.isScalable())
+    return {};
 
-    // Calculate the linearized position of the continuous chunk of elements to
-    // insert.
-    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
-    copy(op.getStaticPosition(), completePositions.begin());
-    int64_t insertBeginPosition =
-        linearize(completePositions, computeStrides(destTy.getShape()));
-
-    SmallVector<Attribute> insertedValues;
-    Type destEltType = destTy.getElementType();
-
-    // The `convertIntegerAttr` method specifically handles the case
-    // for `llvm.mlir.constant` which can hold an attribute with a
-    // different type than the return type.
-    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
-      for (auto value : denseSource.getValues<Attribute>())
-        insertedValues.push_back(convertIntegerAttr(value, destEltType));
-    } else {
-      insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
-    }
+  // Make sure we do not create too many large constants.
+  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
+      !insertOp->hasOneUse())
+    return {};
 
-    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
-    copy(insertedValues, allValues.begin() + insertBeginPosition);
-    auto newAttr = DenseElementsAttr::get(destTy, allValues);
+  // Calculate the linearized position of the continuous chunk of elements to
+  // insert.
+  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+  copy(insertOp.getStaticPosition(), completePositions.begin());
+  int64_t insertBeginPosition =
+      linearize(completePositions, computeStrides(destTy.getShape()));
 
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
-    return success();
-  }
+  SmallVector<Attribute> insertedValues;
+  Type destEltType = destTy.getElementType();
 
-private:
   /// Converts the expected type to an IntegerAttr if there's
   /// a mismatch.
-  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
+  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
     if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
       if (intAttr.getType() != expectedType)
         return IntegerAttr::get(expectedType, intAttr.getInt());
     }
     return attr;
+  };
+
+  // The `convertIntegerAttr` method specifically handles the case
+  // for `llvm.mlir.constant` which can hold an attribute with a
+  // different type than the return type.
+  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
+    for (auto value : denseSource.getValues<Attribute>())
+      insertedValues.push_back(convertIntegerAttr(value, destEltType));
+  } else {
+    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
   }
-};
 
-} // namespace
+  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
+  copy(insertedValues, allValues.begin() + insertBeginPosition);
+  auto newAttr = DenseElementsAttr::get(destTy, allValues);
+
+  return newAttr;
+}
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
-              InsertOpConstantFolder>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
+  // Do not create constants with more than `vectorSizeFoldThreashold` elements,
+  // unless the source vector constant has a single use.
+  constexpr int64_t vectorSizeFoldThreshold = 256;
   // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
   // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
   // (type mismatch).
@@ -3142,6 +3096,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
+  if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
+                                                   adaptor.getDest(),
+                                                   vectorSizeFoldThreshold)) {
+    return res;
+  }
 
   return {};
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 36b37a137ac1e..1ab28b9df2d19 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1517,13 +1517,9 @@ func.func @constant_mask_2d() -> vector<4x4xi1> {
 }
 
 // CHECK-LABEL: func @constant_mask_2d
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x4xi1>
-// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x4xi1> to !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<4xi1>> to vector<4x4xi1>
-// CHECK: return %[[VAL_5]] : vector<4x4xi1>
+// CHECK: %[[VAL_0:.*]] = arith.constant 
+// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
+// CHECK: return %[[VAL_0]] : vector<4x4xi1>
 
 // -----
 
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e66fbe968d9b0..cd83e1239fdda 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
 
 // CHECK-DAG:  %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:  %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
-// CHECK-DAG:  %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
 
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
-
-// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
  // -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
 // CHECK-SAME:                                                                 %[[VAL_0:.*]]: tensor<80x16xf32>,
 // CHECK-SAME:                                                                 %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-
-// CHECK-DAG:       %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
-// CHECK-DAG:       %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
-// CHECK-DAG:       %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 
-// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_9]] : tensor<1x4xf32>
 // CHECK:         }
 
@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
 // CHECK-DAG:       %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK-DAG:       %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
-// CHECK:           %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
 // CHECK:           %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
 // CHECK:           %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..8e5ddbfffcdd9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
 // -----
 
 // ALL-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar() {
+func.func @test_vector_extract_scalar(%idx : index) {
   %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
   // ALL-NOT: vector.shuffle
   // ALL:     vector.extract
   // ALL-NOT: vector.shuffle
-  %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+  %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index c5cb09b9aa9f9..b4ebb14b8829e 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
 
 // CHECK-LABEL: func @transfer_write_arith_constant(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
-//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
-//       CHECK:   %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
-//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[cst:.*]] = arith.constant 5.000000e+00 : f32
+//       CHECK:   memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
   %cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
   vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 20e9400ed698d..5be267c1be984 100644
--- a...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 20, 2025

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes

This PR moves vector.insert canonicalizers for DenseElementsAttr (splat and non splat case) to folders. Folders are local, and it's always better to implement a folder than a canonicalizer.

This PR is mostly NFC-ish, because the functionality mostly remains same, but is now run as part of a folder, which is why some tests are changed, because GreedyPatternRewriter tries to fold by default.


Patch is 24.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128040.diff

7 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+100-141)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+3-7)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+7-19)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir (+2-3)
  • (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+5-9)
  • (modified) mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir (+6-11)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..f21ad23a03c6e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
   return {};
 }
 
+static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
+                                                   Attribute srcAttr) {
+  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
+  if (!denseAttr) {
+    return {};
+  }
+
+  if (denseAttr.isSplat()) {
+    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
+    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
+      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+    return newAttr;
+  }
+
+  auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType());
+  if (vecTy.isScalable())
+    return {};
+
+  if (extractOp.hasDynamicPosition()) {
+    return {};
+  }
+
+  // Calculate the linearized position of the continuous chunk of elements to
+  // extract.
+  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+  copy(extractOp.getStaticPosition(), completePositions.begin());
+  int64_t elemBeginPosition =
+      linearize(completePositions, computeStrides(vecTy.getShape()));
+  auto denseValuesBegin =
+      denseAttr.value_begin<TypedAttr>() + elemBeginPosition;
+
+  TypedAttr newAttr;
+  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
+    SmallVector<Attribute> elementValues(
+        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
+    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
+  } else {
+    newAttr = *denseValuesBegin;
+  }
+
+  return newAttr;
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2101,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return res;
   if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
     return res;
+  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+    return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2121,80 +2166,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
-// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractOp' operand is not defined by a splat vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
-    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
-      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
-class ExtractOpNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Canonicalization for dynamic position not implemented yet.
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
-    // Return if 'ExtractOp' operand is not defined by a compatible vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
-    if (vecTy.isScalable())
-      return failure();
-
-    // The splat case is handled by `ExtractOpSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // Calculate the linearized position of the continuous chunk of elements to
-    // extract.
-    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
-    copy(extractOp.getStaticPosition(), completePositions.begin());
-    int64_t elemBeginPosition =
-        linearize(completePositions, computeStrides(vecTy.getShape()));
-    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
-
-    TypedAttr newAttr;
-    if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
-      SmallVector<Attribute> elementValues(
-          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
-      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
-    } else {
-      newAttr = *denseValuesBegin;
-    }
-
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
 public:
@@ -2332,8 +2303,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -3043,94 +3013,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
-// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
-class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  // Do not create constants with more than `vectorSizeFoldThreashold` elements,
-  // unless the source vector constant has a single use.
-  static constexpr int64_t vectorSizeFoldThreshold = 256;
-
-  LogicalResult matchAndRewrite(InsertOp op,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Canonicalization for dynamic position not implemented yet.
-    if (op.hasDynamicPosition())
-      return failure();
+} // namespace
 
-    // Return if 'InsertOp' operand is not defined by a compatible vector
-    // ConstantOp.
-    TypedValue<VectorType> destVector = op.getDest();
-    Attribute vectorDestCst;
-    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
-      return failure();
-    auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
-    if (!denseDest)
-      return failure();
+static Attribute
+foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
+                                  Attribute dstAttr,
+                                  int64_t maxVectorSizeFoldThreshold) {
+  if (insertOp.hasDynamicPosition())
+    return {};
 
-    VectorType destTy = destVector.getType();
-    if (destTy.isScalable())
-      return failure();
+  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
+  if (!denseDst)
+    return {};
 
-    // Make sure we do not create too many large constants.
-    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
-        !destVector.hasOneUse())
-      return failure();
+  if (!srcAttr) {
+    return {};
+  }
 
-    Value sourceValue = op.getSource();
-    Attribute sourceCst;
-    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
-      return failure();
+  VectorType destTy = insertOp.getDestVectorType();
+  if (destTy.isScalable())
+    return {};
 
-    // Calculate the linearized position of the continuous chunk of elements to
-    // insert.
-    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
-    copy(op.getStaticPosition(), completePositions.begin());
-    int64_t insertBeginPosition =
-        linearize(completePositions, computeStrides(destTy.getShape()));
-
-    SmallVector<Attribute> insertedValues;
-    Type destEltType = destTy.getElementType();
-
-    // The `convertIntegerAttr` method specifically handles the case
-    // for `llvm.mlir.constant` which can hold an attribute with a
-    // different type than the return type.
-    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
-      for (auto value : denseSource.getValues<Attribute>())
-        insertedValues.push_back(convertIntegerAttr(value, destEltType));
-    } else {
-      insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
-    }
+  // Make sure we do not create too many large constants.
+  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
+      !insertOp->hasOneUse())
+    return {};
 
-    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
-    copy(insertedValues, allValues.begin() + insertBeginPosition);
-    auto newAttr = DenseElementsAttr::get(destTy, allValues);
+  // Calculate the linearized position of the continuous chunk of elements to
+  // insert.
+  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+  copy(insertOp.getStaticPosition(), completePositions.begin());
+  int64_t insertBeginPosition =
+      linearize(completePositions, computeStrides(destTy.getShape()));
 
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
-    return success();
-  }
+  SmallVector<Attribute> insertedValues;
+  Type destEltType = destTy.getElementType();
 
-private:
   /// Converts the expected type to an IntegerAttr if there's
   /// a mismatch.
-  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
+  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
     if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
       if (intAttr.getType() != expectedType)
         return IntegerAttr::get(expectedType, intAttr.getInt());
     }
     return attr;
+  };
+
+  // The `convertIntegerAttr` method specifically handles the case
+  // for `llvm.mlir.constant` which can hold an attribute with a
+  // different type than the return type.
+  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
+    for (auto value : denseSource.getValues<Attribute>())
+      insertedValues.push_back(convertIntegerAttr(value, destEltType));
+  } else {
+    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
   }
-};
 
-} // namespace
+  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
+  copy(insertedValues, allValues.begin() + insertBeginPosition);
+  auto newAttr = DenseElementsAttr::get(destTy, allValues);
+
+  return newAttr;
+}
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
-              InsertOpConstantFolder>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
+  // Do not create constants with more than `vectorSizeFoldThreashold` elements,
+  // unless the source vector constant has a single use.
+  constexpr int64_t vectorSizeFoldThreshold = 256;
   // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
   // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
   // (type mismatch).
@@ -3142,6 +3096,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
+  if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
+                                                   adaptor.getDest(),
+                                                   vectorSizeFoldThreshold)) {
+    return res;
+  }
 
   return {};
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 36b37a137ac1e..1ab28b9df2d19 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1517,13 +1517,9 @@ func.func @constant_mask_2d() -> vector<4x4xi1> {
 }
 
 // CHECK-LABEL: func @constant_mask_2d
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x4xi1>
-// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x4xi1> to !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<4xi1>> to vector<4x4xi1>
-// CHECK: return %[[VAL_5]] : vector<4x4xi1>
+// CHECK: %[[VAL_0:.*]] = arith.constant 
+// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
+// CHECK: return %[[VAL_0]] : vector<4x4xi1>
 
 // -----
 
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e66fbe968d9b0..cd83e1239fdda 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
 
 // CHECK-DAG:  %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:  %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
-// CHECK-DAG:  %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
 
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
-
-// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
  // -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
 // CHECK-SAME:                                                                 %[[VAL_0:.*]]: tensor<80x16xf32>,
 // CHECK-SAME:                                                                 %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-
-// CHECK-DAG:       %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
-// CHECK-DAG:       %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
-// CHECK-DAG:       %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 
-// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_9]] : tensor<1x4xf32>
 // CHECK:         }
 
@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
 // CHECK-DAG:       %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK-DAG:       %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
-// CHECK:           %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
 // CHECK:           %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
 // CHECK:           %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..8e5ddbfffcdd9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
 // -----
 
 // ALL-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar() {
+func.func @test_vector_extract_scalar(%idx : index) {
   %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
   // ALL-NOT: vector.shuffle
   // ALL:     vector.extract
   // ALL-NOT: vector.shuffle
-  %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+  %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index c5cb09b9aa9f9..b4ebb14b8829e 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
 
 // CHECK-LABEL: func @transfer_write_arith_constant(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
-//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
-//       CHECK:   %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
-//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[cst:.*]] = arith.constant 5.000000e+00 : f32
+//       CHECK:   memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
   %cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
   vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 20e9400ed698d..5be267c1be984 100644
--- a...
[truncated]

@Groverkss Groverkss force-pushed the insert-op-canonicalizer-to-folder branch from 0d08560 to 14f025f Compare March 3, 2025 10:58
@Groverkss
Copy link
Member Author

Groverkss commented Mar 3, 2025

This PR is now ready to review. It's a InsertOp variant of #127995 so I'm hoping it should be fairly non-controversial.

Ping @banach-space @dcaballe @kuhar :)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about the delay, I was waiting for #129517.

LGTM, thanks!

So what's left to get rid of vector.{insert|extract}element?

@Groverkss
Copy link
Member Author

I'm guessing with the doc change landed, this is a trivial change now, so landing this :)

@Groverkss Groverkss merged commit d10dca6 into llvm:main Mar 6, 2025
11 checks passed
@Groverkss
Copy link
Member Author

Groverkss commented Mar 6, 2025

Sorry about the delay, I was waiting for #129517.

LGTM, thanks!

So what's left to get rid of vector.{insert|extract}element?

I'm going to benchmark the replacement in IREE and check what happens for performance. If nothing significant, I'll send the next pr that removes it.

jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
…r to folders (llvm#128040)

This PR moves vector.insert canonicalizers for DenseElementsAttr (splat
and non splat case) to folders. Folders are local, and it's always
better to implement a folder than a canonicalizer.

This PR is mostly NFC-ish, because the functionality mostly remains
same, but is now run as part of a folder, which is why some tests are
changed, because GreedyPatternRewriter tries to fold by default.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants