[mlir][Vector] Move vector.insert canonicalizers for DenseElementsAttr to folders #128040
Depends on #127995, please review top commit only.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Kunwar Grover (Groverkss)

Changes

This PR moves the vector.insert canonicalizers for DenseElementsAttr (splat and non-splat cases) to folders. Folders are local, and it's always better to implement a folder than a canonicalizer.

This PR is mostly NFC: the functionality stays largely the same but now runs as part of a folder. Some tests change as a result, because GreedyPatternRewriter tries to fold by default.

Patch is 24.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128040.diff

7 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..f21ad23a03c6e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
return {};
}
+static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
+ Attribute srcAttr) {
+ auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
+ if (!denseAttr) {
+ return {};
+ }
+
+ if (denseAttr.isSplat()) {
+ Attribute newAttr = denseAttr.getSplatValue<Attribute>();
+ if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
+ newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+ return newAttr;
+ }
+
+ auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType());
+ if (vecTy.isScalable())
+ return {};
+
+ if (extractOp.hasDynamicPosition()) {
+ return {};
+ }
+
+ // Calculate the linearized position of the continuous chunk of elements to
+ // extract.
+ llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+ copy(extractOp.getStaticPosition(), completePositions.begin());
+ int64_t elemBeginPosition =
+ linearize(completePositions, computeStrides(vecTy.getShape()));
+ auto denseValuesBegin =
+ denseAttr.value_begin<TypedAttr>() + elemBeginPosition;
+
+ TypedAttr newAttr;
+ if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
+ SmallVector<Attribute> elementValues(
+ denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
+ newAttr = DenseElementsAttr::get(resVecTy, elementValues);
+ } else {
+ newAttr = *denseValuesBegin;
+ }
+
+ return newAttr;
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2101,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return res;
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
return res;
+ if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+ return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2121,80 +2166,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
}
};
-// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- // Return if 'ExtractOp' operand is not defined by a splat vector
- // ConstantOp.
- Value sourceVector = extractOp.getVector();
- Attribute vectorCst;
- if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
- return failure();
- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
- if (!splat)
- return failure();
- TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
- if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
- newAttr = DenseElementsAttr::get(vecDstType, newAttr);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
- return success();
- }
-};
-
-// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
-class ExtractOpNonSplatConstantFolder final
- : public OpRewritePattern<ExtractOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- // TODO: Canonicalization for dynamic position not implemented yet.
- if (extractOp.hasDynamicPosition())
- return failure();
-
- // Return if 'ExtractOp' operand is not defined by a compatible vector
- // ConstantOp.
- Value sourceVector = extractOp.getVector();
- Attribute vectorCst;
- if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
- return failure();
-
- auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
- if (vecTy.isScalable())
- return failure();
-
- // The splat case is handled by `ExtractOpSplatConstantFolder`.
- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
- if (!dense || dense.isSplat())
- return failure();
-
- // Calculate the linearized position of the continuous chunk of elements to
- // extract.
- llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
- copy(extractOp.getStaticPosition(), completePositions.begin());
- int64_t elemBeginPosition =
- linearize(completePositions, computeStrides(vecTy.getShape()));
- auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
-
- TypedAttr newAttr;
- if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
- SmallVector<Attribute> elementValues(
- denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
- newAttr = DenseElementsAttr::get(resVecTy, elementValues);
- } else {
- newAttr = *denseValuesBegin;
- }
-
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
- return success();
- }
-};
-
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
@@ -2332,8 +2303,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
@@ -3043,94 +3013,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
}
};
-// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
-class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- // Do not create constants with more than `vectorSizeFoldThreashold` elements,
- // unless the source vector constant has a single use.
- static constexpr int64_t vectorSizeFoldThreshold = 256;
-
- LogicalResult matchAndRewrite(InsertOp op,
- PatternRewriter &rewriter) const override {
- // TODO: Canonicalization for dynamic position not implemented yet.
- if (op.hasDynamicPosition())
- return failure();
+} // namespace
- // Return if 'InsertOp' operand is not defined by a compatible vector
- // ConstantOp.
- TypedValue<VectorType> destVector = op.getDest();
- Attribute vectorDestCst;
- if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
- return failure();
- auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
- if (!denseDest)
- return failure();
+static Attribute
+foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
+ Attribute dstAttr,
+ int64_t maxVectorSizeFoldThreshold) {
+ if (insertOp.hasDynamicPosition())
+ return {};
- VectorType destTy = destVector.getType();
- if (destTy.isScalable())
- return failure();
+ auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
+ if (!denseDst)
+ return {};
- // Make sure we do not create too many large constants.
- if (destTy.getNumElements() > vectorSizeFoldThreshold &&
- !destVector.hasOneUse())
- return failure();
+ if (!srcAttr) {
+ return {};
+ }
- Value sourceValue = op.getSource();
- Attribute sourceCst;
- if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
- return failure();
+ VectorType destTy = insertOp.getDestVectorType();
+ if (destTy.isScalable())
+ return {};
- // Calculate the linearized position of the continuous chunk of elements to
- // insert.
- llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
- copy(op.getStaticPosition(), completePositions.begin());
- int64_t insertBeginPosition =
- linearize(completePositions, computeStrides(destTy.getShape()));
-
- SmallVector<Attribute> insertedValues;
- Type destEltType = destTy.getElementType();
-
- // The `convertIntegerAttr` method specifically handles the case
- // for `llvm.mlir.constant` which can hold an attribute with a
- // different type than the return type.
- if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
- for (auto value : denseSource.getValues<Attribute>())
- insertedValues.push_back(convertIntegerAttr(value, destEltType));
- } else {
- insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
- }
+ // Make sure we do not create too many large constants.
+ if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
+ !insertOp->hasOneUse())
+ return {};
- auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
- copy(insertedValues, allValues.begin() + insertBeginPosition);
- auto newAttr = DenseElementsAttr::get(destTy, allValues);
+ // Calculate the linearized position of the continuous chunk of elements to
+ // insert.
+ llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+ copy(insertOp.getStaticPosition(), completePositions.begin());
+ int64_t insertBeginPosition =
+ linearize(completePositions, computeStrides(destTy.getShape()));
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
- return success();
- }
+ SmallVector<Attribute> insertedValues;
+ Type destEltType = destTy.getElementType();
-private:
/// Converts the expected type to an IntegerAttr if there's
/// a mismatch.
- Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
+ auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
if (intAttr.getType() != expectedType)
return IntegerAttr::get(expectedType, intAttr.getInt());
}
return attr;
+ };
+
+ // The `convertIntegerAttr` method specifically handles the case
+ // for `llvm.mlir.constant` which can hold an attribute with a
+ // different type than the return type.
+ if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
+ for (auto value : denseSource.getValues<Attribute>())
+ insertedValues.push_back(convertIntegerAttr(value, destEltType));
+ } else {
+ insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
}
-};
-} // namespace
+ auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
+ copy(insertedValues, allValues.begin() + insertBeginPosition);
+ auto newAttr = DenseElementsAttr::get(destTy, allValues);
+
+ return newAttr;
+}
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
- InsertOpConstantFolder>(context);
+ results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
}
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
+ // Do not create constants with more than `vectorSizeFoldThreshold` elements,
+ // unless the source vector constant has a single use.
+ constexpr int64_t vectorSizeFoldThreshold = 256;
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
// (type mismatch).
@@ -3142,6 +3096,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
+ if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
+ adaptor.getDest(),
+ vectorSizeFoldThreshold)) {
+ return res;
+ }
return {};
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 36b37a137ac1e..1ab28b9df2d19 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1517,13 +1517,9 @@ func.func @constant_mask_2d() -> vector<4x4xi1> {
}
// CHECK-LABEL: func @constant_mask_2d
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x4xi1>
-// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x4xi1> to !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<4xi1>> to vector<4x4xi1>
-// CHECK: return %[[VAL_5]] : vector<4x4xi1>
+// CHECK: %[[VAL_0:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
+// CHECK: return %[[VAL_0]] : vector<4x4xi1>
// -----
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e66fbe968d9b0..cd83e1239fdda 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
-
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
// -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
-// CHECK-DAG: %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_9]] : tensor<1x4xf32>
// CHECK: }
@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK-DAG: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
-// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..8e5ddbfffcdd9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
// -----
// ALL-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar() {
+func.func @test_vector_extract_scalar(%idx : index) {
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
// ALL-NOT: vector.shuffle
// ALL: vector.extract
// ALL-NOT: vector.shuffle
- %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+ %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
return
}
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index c5cb09b9aa9f9..b4ebb14b8829e 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
// CHECK-LABEL: func @transfer_write_arith_constant(
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
-// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
-// CHECK: %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
-// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[cst:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK: memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
%cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 20e9400ed698d..5be267c1be984 100644
--- a...
[truncated]
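Both folders in the diff above locate a contiguous chunk of elements by zero-padding the static position to full rank and linearizing it against the strides of the vector shape. The following is a minimal standalone sketch of that arithmetic, not the patch's code. The names `rowMajorStrides`, `linearizedOffset`, `chunkSize`, `extractChunk` and `insertChunk` are hypothetical, a flat `std::vector<int64_t>` stands in for `DenseElementsAttr`, and the sketch assumes that `computeStrides` yields row-major (suffix-product) strides and that `linearize` is the dot product of position and strides.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Row-major (suffix-product) strides, e.g. {2, 3, 4} -> {12, 4, 1}.
// Hypothetical stand-in for the `computeStrides` helper used in the patch.
Shape rowMajorStrides(const Shape &shape) {
  Shape strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 1;)
    strides[i - 1] = strides[i] * shape[i];
  return strides;
}

// Zero-pads a possibly partial position to full rank and returns the flat
// offset of the first element it addresses.
int64_t linearizedOffset(const Shape &position, const Shape &shape) {
  assert(position.size() <= shape.size() && "position has too many indices");
  Shape full(shape.size(), 0);
  std::copy(position.begin(), position.end(), full.begin());
  Shape strides = rowMajorStrides(shape);
  int64_t offset = 0;
  for (std::size_t i = 0; i < full.size(); ++i)
    offset += full[i] * strides[i];
  return offset;
}

// Number of elements addressed by a position with `positionRank` indices:
// the product of the remaining trailing dimensions.
int64_t chunkSize(std::size_t positionRank, const Shape &shape) {
  int64_t count = 1;
  for (std::size_t i = positionRank; i < shape.size(); ++i)
    count *= shape[i];
  return count;
}

// Extract side: copy the contiguous chunk that starts at the linearized offset.
std::vector<int64_t> extractChunk(const std::vector<int64_t> &values,
                                  const Shape &shape, const Shape &position) {
  int64_t begin = linearizedOffset(position, shape);
  int64_t count = chunkSize(position.size(), shape);
  assert(begin + count <= static_cast<int64_t>(values.size()));
  return std::vector<int64_t>(values.begin() + begin,
                              values.begin() + begin + count);
}

// Insert side: overwrite the chunk at the linearized offset in a copy of the
// destination and return the updated copy.
std::vector<int64_t> insertChunk(std::vector<int64_t> dest,
                                 const std::vector<int64_t> &src,
                                 const Shape &shape, const Shape &position) {
  int64_t begin = linearizedOffset(position, shape);
  assert(begin + static_cast<int64_t>(src.size()) <=
         static_cast<int64_t>(dest.size()));
  std::copy(src.begin(), src.end(), dest.begin() + begin);
  return dest;
}
```

As a usage example under the same assumptions, starting from an all-zero buffer of shape `{4, 4}` and calling `insertChunk` with the row `{1, 1, 0, 0}` at positions `{0}` and `{1}` yields a buffer whose first two rows are `1, 1, 0, 0` and whose remaining rows are zero. That is the same layout as the single folded constant expected by the updated `constant_mask_2d` check lines.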
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
-
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
// -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
-// CHECK-DAG: %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_9]] : tensor<1x4xf32>
// CHECK: }
@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK-DAG: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
-// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..8e5ddbfffcdd9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
// -----
// ALL-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar() {
+func.func @test_vector_extract_scalar(%idx : index) {
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
// ALL-NOT: vector.shuffle
// ALL: vector.extract
// ALL-NOT: vector.shuffle
- %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+ %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
return
}
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index c5cb09b9aa9f9..b4ebb14b8829e 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
// CHECK-LABEL: func @transfer_write_arith_constant(
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
-// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
-// CHECK: %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
-// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[cst:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK: memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
%cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 20e9400ed698d..5be267c1be984 100644
--- a...
[truncated]
This PR is now ready for review. It's an InsertOp variant of #127995, so I'm hoping it will be fairly non-controversial. Ping @banach-space @dcaballe @kuhar :)
Sorry about the delay, I was waiting for #129517.
LGTM, thanks!
So what's left to get rid of vector.{insert|extract}element?
I'm guessing that, with the doc change landed, this is a trivial change now, so I'm landing this :)
I'm going to benchmark the replacement in IREE and check what happens to performance. If nothing significant changes, I'll send the next PR that removes it.
…r to folders (llvm#128040) This PR moves the vector.insert canonicalizers for DenseElementsAttr (splat and non-splat cases) to folders. Folders are local, and it's always better to implement a folder than a canonicalizer. This PR is mostly NFC-ish: the functionality mostly remains the same, but it now runs as part of a folder. Some tests change as a result, because GreedyPatternRewriter tries to fold by default.
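For readers who want to see the index arithmetic behind the new fold in isolation, here is a minimal standalone sketch. It does not use any MLIR API: plain `std::vector<int>` stands in for `DenseElementsAttr`, and the names `computeRowMajorStrides`, `linearizePosition`, `foldInsertIntoDenseConstant`, and the `destHasSingleUse` flag are hypothetical, chosen only for illustration. It assumes the size guard means "skip when the destination is larger than the threshold, unless it has a single use", as the comment in the patch suggests.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Row-major strides for a static shape, e.g. {4, 4} -> {4, 1}.
std::vector<int64_t> computeRowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Flat offset of a (possibly partial) position. Missing trailing indices are
// treated as 0, so a partial position points at the start of a sub-vector.
int64_t linearizePosition(const std::vector<int64_t> &position,
                          const std::vector<int64_t> &strides) {
  assert(position.size() <= strides.size() && "position has too many indices");
  int64_t offset = 0;
  for (size_t i = 0; i < position.size(); ++i)
    offset += position[i] * strides[i];
  return offset;
}

// Hypothetical stand-in for the fold: returns the new flattened constant, or
// std::nullopt when the fold is skipped because of the size guard.
std::optional<std::vector<int>> foldInsertIntoDenseConstant(
    const std::vector<int> &inserted, const std::vector<int> &dest,
    const std::vector<int64_t> &destShape,
    const std::vector<int64_t> &position, bool destHasSingleUse,
    int64_t sizeThreshold = 256) {
  // Assumed guard: avoid materializing many large constants.
  if (!destHasSingleUse && static_cast<int64_t>(dest.size()) > sizeThreshold)
    return std::nullopt;
  int64_t begin =
      linearizePosition(position, computeRowMajorStrides(destShape));
  // Copy the destination values and overwrite the inserted range in place.
  std::vector<int> result = dest;
  std::copy(inserted.begin(), inserted.end(), result.begin() + begin);
  return result;
}
```

Calling `foldInsertIntoDenseConstant` twice on a 4x4 all-zero destination, inserting `{1, 1, 0, 0}` at positions `{0}` and `{1}`, yields the flattened form of the single 4x4 constant that the updated `constant_mask_2d` test expects (with 1 standing in for `true`).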