[mlir][vector] Fix emulation of "narrow" type vector.store
#133231
base: main
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Below are two examples of "narrow" `vector.store`s. The first example does not require partial stores and hence no RMW stores. This is currently emulated correctly.

```mlir
func.func @example_1(%arg0: vector<4xi2>) {
%0 = memref.alloc() : memref<13xi2>
%c4 = arith.constant 4 : index
vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
return
}
```

The second example below does require a partial store (due to the non-zero offset):

```mlir
func.func @example_2(%arg0: vector<4xi2>) {
%0 = memref.alloc() : memref<13xi2>
%c3 = arith.constant 3 : index
vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
return
}
```

This is currently incorrectly emulated as a single "full" store (note that the offset is incorrect):

```mlir
func.func @example_2(%arg0: vector<4xi2>) {
%alloc = memref.alloc() : memref<4xi8>
%0 = vector.bitcast %arg0 : vector<4xi2> to vector<1xi8>
%c0 = arith.constant 0 : index
vector.store %0, %alloc[%c0] : memref<4xi8>, vector<1xi8>
return
}
```

This PR fixes this issue. Additional comments are added to clarify the current logic.

Full diff: https://github.com/llvm/llvm-project/pull/133231.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 5debebd3218ed..7beb85f3a2a61 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -591,12 +591,12 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>
auto origElements = valueToStore.getType().getNumElements();
- // Note, per-element-alignment was already verified above.
- bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+ // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
+ // non-zero. As such, this is not needed.
OpFoldResult linearizedIndices;
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
@@ -608,8 +608,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedNumFrontPadElems =
- isDivisibleInSize ? 0
- : getConstantIntValue(linearizedInfo.intraDataOffset);
+ getConstantIntValue(linearizedInfo.intraDataOffset);
if (!foldedNumFrontPadElems) {
return rewriter.notifyMatchFailure(
@@ -619,15 +618,39 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto memrefBase = cast<MemRefValue>(adaptor.getBase());
- // Conditions when atomic RMWs are not needed:
+ // RMWs are not needed when:
+ // * no _partial_ stores are required.
+ // A partial store is defined as a store in which only a part of the
+ // container element is overwritten, e.g.
+ //
+ // Dest before (8 bits)
+ // +----------+
+ // | 11000000 |
+ // +----------+
+ //
+ // Dest after storing 0xF at offset 4 (in bits)
+ // +----------+
+ // | 11001111 |
+ // +----------+
+ //
+ // At a higher level, this translates to:
// 1. The source vector size (in bits) is a multiple of byte size.
- // 2. The address of the store is aligned to the emulated width boundary.
+ // 2. The address of the store is aligned to the container type width
+ // boundary.
+ //
+ // EXAMPLE 1:
+ // Requires partial store:
+ // vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
//
- // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
- // need unaligned emulation because the store address is aligned and the
- // source is a whole byte.
- bool emulationRequiresPartialStores =
- !isDivisibleInSize || *foldedNumFrontPadElems != 0;
+ // EXAMPLE 2:
+ // Does not require a partial store:
+ // vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+ //
+ // TODO: Take linearizedInfo.linearizedOffset into account. This is
+ // currently not needed/used/exercised as all our tests set offset to 0.
+ bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
+ emulationRequiresPartialStores = true;
+
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
auto numElements = origElements / emulatedPerContainerElem;
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 6fc974200c6f3..40961349a4a62 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -361,6 +361,75 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
/// vector.store
///----------------------------------------------------------------------------------------
+// -----
+
+// Most basic example to demonstrate where partial stores are not needed.
+
+func.func @vector_store_i2_const_index_no_partial_store(%arg0: vector<4xi2>) {
+ %0 = memref.alloc() : memref<13xi2>
+ %c4 = arith.constant 4 : index
+ vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+ return
+}
+// CHECK-LABEL: func.func @vector_store_i2_const_index_no_partial_store(
+// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK-NOT: memref.generic_atomic_rmw
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[IDX:.*]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[IDX]]] : memref<4xi8>, vector<1xi8>
+
+// -----
+
+// Small modification of the example above to demonstrate where partial stores
+// are needed.
+
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<4xi2>) {
+ %0 = memref.alloc() : memref<13xi2>
+ %c3 = arith.constant 3 : index
+ vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
+ return
+}
+
+// CHECK-LABEL: func.func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>
+
+// First atomic RMW:
+// CHECK: %[[IDX_1:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
+// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
+// CHECK: ^bb0(%[[VAL_8:.*]]: i8):
+// CHECK: %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
+// CHECK: %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[RES_1]] : i8
+// CHECK: }
+
+// Second atomic RMW:
+// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK: %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
+// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
+// CHECK: %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
+// CHECK: ^bb0(%[[VAL_20:.*]]: i8):
+// CHECK: %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
+// CHECK: %[[DONWCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DONWCAST_2]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[RES_2]] : i8
+// CHECK: }
+
+// -----
+
func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
%src = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
✅ With the latest revision this PR passed the C/C++ code formatter.
Below are two examples of "narrow" `vector.store`s. The first example does not require partial stores and hence no RMW stores. This is currently emulated correctly.

```mlir
func.func @example_1(%arg0: vector<4xi2>) {
  %0 = memref.alloc() : memref<13xi2>
  %c4 = arith.constant 4 : index
  vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
  return
}
```

The second example requires a partial (and hence RMW) store due to the non-zero offset.

```mlir
func.func @example_2(%arg0: vector<4xi2>) {
  %0 = memref.alloc() : memref<13xi2>
  %c3 = arith.constant 3 : index
  vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
  return
}
```

This is currently incorrectly emulated as a single "full" store (note that the offset is incorrect) instead of partial stores:

```mlir
func.func @example_2(%arg0: vector<4xi2>) {
  %alloc = memref.alloc() : memref<4xi8>
  %0 = vector.bitcast %arg0 : vector<4xi2> to vector<1xi8>
  %c0 = arith.constant 0 : index
  vector.store %0, %alloc[%c0] : memref<4xi8>, vector<1xi8>
  return
}
```

The incorrect emulation stems from this simplified (i.e. incomplete) calculation of the front padding:

```cpp
std::optional<int64_t> foldedNumFrontPadElems =
    isDivisibleInSize ? 0
                      : getConstantIntValue(linearizedInfo.intraDataOffset);
```

Since `isDivisibleInSize` is `true` (i8 / i2 = 4):

* front padding is set to `0` and, as a result,
* the input offset (`%c3`) is ignored, and
* we incorrectly assume that partial stores won't be needed.

Note, however, that in `@example_2` we are storing `vector<4xi2>` into `memref<13xi2>` (note _different_ trailing dims) and hence partial stores might in fact be required. The condition above is updated to:

```cpp
std::optional<int64_t> foldedNumFrontPadElems =
    (isDivisibleInSize && trailingDimsMatch)
        ? 0
        : getConstantIntValue(linearizedInfo.intraDataOffset);
```

This change ensures that the input offset is properly taken into account, which fixes the issue. Additional comments are added to clarify the current logic.
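For intuition, the front-padding arithmetic can be reproduced with plain integer math. The sketch below is a standalone illustration rather than the pattern code (the variable names are illustrative); the bit widths and store offsets are taken from the i2-into-i8 examples above.

```cpp
#include <cstdint>
#include <cstdio>

// Standalone illustration of the front-padding arithmetic (not the MLIR
// pattern code). For i2 elements emulated in i8 containers, one container
// byte holds 8 / 2 = 4 emulated elements.
int main() {
  constexpr int64_t emulatedBits = 2;                                  // i2
  constexpr int64_t containerBits = 8;                                 // i8
  constexpr int64_t elemsPerContainer = containerBits / emulatedBits;  // 4

  // The store offsets used by @example_1 (%c4) and @example_2 (%c3).
  for (int64_t offset : {4, 3}) {
    // Number of emulated elements preceding `offset` within its container
    // byte -- this is the front padding (intraDataOffset).
    int64_t frontPad = offset % elemsPerContainer;
    // A store that starts mid-byte only partially overwrites that byte and
    // therefore needs read-modify-write (partial) stores.
    bool needsPartialStore = frontPad != 0;
    std::printf("offset %lld: front padding = %lld -> partial stores %s\n",
                static_cast<long long>(offset),
                static_cast<long long>(frontPad),
                needsPartialStore ? "required" : "not required");
  }
  return 0;
}
```

Running this prints front padding 0 for `%c4` (no RMW needed) and 3 for `%c3` (RMW needed), matching the two examples.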
@@ -593,10 +593,20 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
    auto origElements = valueToStore.getType().getNumElements();
    // Note, per-element-alignment was already verified above.
    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
    // Do the trailing dim for source and destination match? If yes, then the
Trying to understand this. If both the source's and destination's last dims are the same, why would we want to skip this check, unless the last dim is aligned?
> why would we want to skip this check?

Are you referring to the calculation of `foldedNumFrontPadElems`?

If so, the reason is tied to a key assumption we're making here:

- The vector being stored is 1D (see this code).

Under that assumption, here's one possible case:

`vector.store %arg0, %0[%idx] : memref<4xi2>, vector<4xi2>`

In this case, `%idx` is 0. Assuming the "emulated" and "container" types are aligned, we definitely don't need padding here, hence `foldedNumFrontPadElems = 0`.

However, in a case like:

`vector.store %arg0, %0[%idx] : memref<8xi2>, vector<4xi2>`

`%idx` could be 0 or 1, and if it's 1, padding is required even when the "emulated" and "container" types are aligned.

That's the nuance I was trying to capture.
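The cases above can be condensed into a small sketch. The helper below is hypothetical (its name, signature, and the surrounding `main` are not part of the PR) and assumes a 1-D destination memref with a constant, in-bounds store index, per the assumption stated above.

```cpp
#include <cstdint>
#include <cstdio>
#include <optional>
#include <utility>

// Hypothetical helper illustrating the reasoning above (not the PR's code).
std::optional<int64_t> frontPadElems(int64_t memrefTrailingDim,
                                     int64_t numStoredElems,
                                     std::optional<int64_t> constIndex,
                                     int64_t elemsPerContainer) {
  // If the memref's trailing dim equals the number of stored elements
  // (e.g. memref<4xi2> vs. vector<4xi2>), the only valid start index is 0,
  // so the front padding is trivially 0.
  // (The PR additionally requires isDivisibleInSize before this shortcut.)
  if (memrefTrailingDim == numStoredElems)
    return 0;
  // Otherwise (e.g. memref<8xi2>) a non-zero index may land mid-container,
  // so the padding has to be derived from the index itself.
  if (!constIndex)
    return std::nullopt; // dynamic index: padding unknown at compile time
  return *constIndex % elemsPerContainer;
}

int main() {
  // memref<4xi2>,  vector<4xi2>, %idx = 0 -> 0
  // memref<8xi2>,  vector<4xi2>, %idx = 1 -> 1
  // memref<13xi2>, vector<4xi2>, %c3  = 3 -> 3 (@example_2 above)
  const std::pair<int64_t, int64_t> cases[] = {{4, 0}, {8, 1}, {13, 3}};
  for (auto [dim, idx] : cases) {
    auto pad = frontPadElems(dim, /*numStoredElems=*/4, idx,
                             /*elemsPerContainer=*/4);
    std::printf("trailing dim %lld, index %lld -> front padding %lld\n",
                static_cast<long long>(dim), static_cast<long long>(idx),
                static_cast<long long>(*pad));
  }
  return 0;
}
```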
Below are two examples of "narrow" `vector.store`s (see `@example_1` and `@example_2` above). The first example does not require partial stores and hence no RMW stores; this is currently emulated correctly. The second example requires a partial (and hence RMW) store due to the offset pointing outside the emulated type boundary (`%c3`). This is currently incorrectly emulated as a single "full" store (note that the offset is incorrect) instead of partial stores.

The incorrect emulation stems from the simplified (i.e. incomplete) calculation of the front padding quoted above. Since `isDivisibleInSize` is `true` (i8 / i2 = 4):

* front padding is set to `0` and, as a result,
* the input offset (`%c3`) is ignored, and
* we incorrectly assume that partial stores won't be needed.

Note that in both examples we are storing `vector<4xi2>` into `memref<13xi2>` (note different trailing dims) and hence partial stores might in fact be required. The condition above is therefore updated to also require matching trailing dims (`isDivisibleInSize && trailingDimsMatch`). This change ensures that the input offset is properly taken into account, which fixes the issue. It doesn't affect `@example_1`.

Additional comments are added to clarify the current logic.
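For intuition, the bit-level effect of a partial (RMW) store matches the "Dest before / Dest after" diagram in the comment added by this PR. The snippet below is only a conceptual byte-manipulation illustration (the mask and values mirror that diagram); the actual lowering uses `vector.bitcast`, `arith.select`, and `memref.generic_atomic_rmw`, as shown by the new test cases.

```cpp
#include <cstdint>
#include <cstdio>

// Conceptual, byte-level view of a partial store: read the container byte,
// overwrite only the masked bits, write the byte back. The real lowering
// performs the equivalent with vector ops inside memref.generic_atomic_rmw.
int main() {
  uint8_t dest = 0b11000000; // "Dest before (8 bits)" from the PR comment
  uint8_t mask = 0b00001111; // the four bit positions being overwritten
  uint8_t val  = 0b00001111; // 0xF placed at those positions

  // Read-modify-write: keep bits outside the mask, replace bits inside it.
  dest = static_cast<uint8_t>((dest & ~mask) | (val & mask));

  std::printf("0b");
  for (int bit = 7; bit >= 0; --bit)
    std::printf("%d", (dest >> bit) & 1);
  std::printf("\n"); // prints 0b11001111, i.e. "Dest after" from the comment
  return 0;
}
```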