
[mlir][vector] Fix emulation of "narrow" type vector.store #133231

Open

banach-space wants to merge 1 commit into main from andrzej/narrow_type/fix

Conversation

banach-space
Contributor

@banach-space banach-space commented Mar 27, 2025

Below are two examples of "narrow" `vector.stores`. The first example
does not require partial stores and hence no RMW (read-modify-write)
stores. This is currently emulated correctly.

```
func.func @example_1(%arg0: vector<4xi2>) {
    %0 = memref.alloc() : memref<13xi2>
    %c4 = arith.constant 4 : index
    vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
    return
}
```

The second example requires a partial (and hence RMW) store because the
offset (%c3) is not aligned to the container type (i8) boundary.

```
func.func @example_2(%arg0: vector<4xi2>) {
    %0 = memref.alloc() : memref<13xi2>
    %c3 = arith.constant 3 : index
    vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
    return
}
```
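
To make the difference concrete, here is a small standalone sketch of the bit
arithmetic behind the two examples. It is illustrative only and not part of
the patch; the constants simply mirror the i2 element type and the assumed i8
container type from the examples above.

```cpp
#include <cstdint>
#include <cstdio>

// Illustration: where does a vector<4xi2> store land relative to the 8-bit
// container (i8) boundaries? 2 bits per i2 element, 8 bits per container byte.
int main() {
  const int64_t bitsPerElem = 2, bitsPerContainer = 8, numElems = 4;
  const int64_t elemOffsets[] = {4, 3}; // %c4 in @example_1, %c3 in @example_2

  for (int64_t elemOffset : elemOffsets) {
    int64_t startBit = elemOffset * bitsPerElem;        // 8 vs 6
    int64_t endBit = startBit + numElems * bitsPerElem; // 16 vs 14
    bool aligned = startBit % bitsPerContainer == 0 &&
                   endBit % bitsPerContainer == 0;
    // Offset 4 -> bits [8, 16): exactly one full byte, a plain store suffices.
    // Offset 3 -> bits [6, 14): straddles two bytes, so two partial (RMW)
    // stores are needed.
    std::printf("offset %lld: bits [%lld, %lld), container-aligned: %s\n",
                (long long)elemOffset, (long long)startBit, (long long)endBit,
                aligned ? "yes" : "no");
  }
  return 0;
}
```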

This is currently incorrectly emulated as a single "full" store (note
that the offset is incorrect) instead of partial stores:

```
func.func @example_2(%arg0: vector<4xi2>) {
  %alloc = memref.alloc() : memref<4xi8>
  %0 = vector.bitcast %arg0 : vector<4xi2> to vector<1xi8>
  %c0 = arith.constant 0 : index
  vector.store %0, %alloc[%c0] : memref<4xi8>, vector<1xi8>
  return
}
```

The incorrect emulation stems from this simplified (i.e. incomplete)
calculation of the front padding:

```cpp
    std::optional<int64_t> foldedNumFrontPadElems =
        isDivisibleInSize ? 0
                          : getConstantIntValue(linearizedInfo.intraDataOffset);
```

Since `isDivisibleInSize` is `true` (i8 / i2 = 4):
  * front padding is set to `0` and, as a result,
  * the input offset (`%c3`) is ignored, and
  * we incorrectly assume that partial stores won't be needed.

Note that in both examples we are storing `vector<4xi2>` into
`memref<13xi2>` (note the _different_ trailing dims: 4 vs. 13), so partial
stores might in fact be required. The condition above is updated to:

```cpp
    std::optional<int64_t> foldedNumFrontPadElems =
        (isDivisibleInSize && trailingDimsMatch)
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);
```

This change ensures that the input offset is properly taken into
account, which fixes the issue. It doesn't affect @example_1.
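
As a rough standalone model of the before/after logic, the sketch below
mirrors the variable names from the description. `numFrontPadElems` is a
hypothetical free function (the actual conversion pattern is not reproduced),
and `trailingDimsMatch` is assumed to compare the trailing dim of the stored
vector with that of the destination memref, as discussed in the review below.

```cpp
#include <cstdint>
#include <optional>

// Sketch of the front-padding computation for @example_2:
//   vector<4xi2> stored into memref<13xi2> at offset %c3
//   -> isDivisibleInSize = true (four i2 elements fill an i8 exactly),
//      trailingDimsMatch = false (4 vs 13),
//      intraDataOffset   = 3.
std::optional<int64_t>
numFrontPadElems(bool isDivisibleInSize, bool trailingDimsMatch,
                 std::optional<int64_t> intraDataOffset) {
  // Old (buggy) condition: `isDivisibleInSize ? 0 : intraDataOffset` would
  // return 0 here and silently drop the %c3 offset.
  // New condition: only fold to 0 when the trailing dims also match.
  return (isDivisibleInSize && trailingDimsMatch)
             ? std::optional<int64_t>(0)
             : intraDataOffset;
}

// For @example_2 this now yields 3, so the partial-store (RMW) path is taken;
// for @example_1 (offset %c4) intraDataOffset is 0 and nothing changes.
```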

Additional comments are added to clarify the current logic.

@llvmbot
Member

llvmbot commented Mar 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Full diff: https://github.com/llvm/llvm-project/pull/133231.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+34-11)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+69)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 5debebd3218ed..7beb85f3a2a61 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -591,12 +591,12 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    // Note, per-element-alignment was already verified above.
-    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
+    // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
+    // non-zero. As such, this is not needed.
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
@@ -608,8 +608,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isDivisibleInSize ? 0
-                          : getConstantIntValue(linearizedInfo.intraDataOffset);
+                          getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -619,15 +618,39 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto memrefBase = cast<MemRefValue>(adaptor.getBase());
 
-    // Conditions when atomic RMWs are not needed:
+    // RMWs are not needed when:
+    //  * no _partial_ stores are required.
+    // A partial store is defined as a store in which only a part of the
+    // container element is overwritten, e.g.
+    //
+    //    Dest before (8 bits)
+    //        +----------+
+    //        | 11000000 |
+    //        +----------+
+    //
+    //    Dest after storing 0xF at offset 4 (in bits)
+    //        +----------+
+    //        | 11001111 |
+    //        +----------+
+    //
+    // At a higher level, this translates to:
     // 1. The source vector size (in bits) is a multiple of byte size.
-    // 2. The address of the store is aligned to the emulated width boundary.
+    // 2. The address of the store is aligned to the container type width
+    //    boundary.
+    //
+    // EXAMPLE 1:
+    //  Requires partial store:
+    //    vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
     //
-    // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
-    // need unaligned emulation because the store address is aligned and the
-    // source is a whole byte.
-    bool emulationRequiresPartialStores =
-        !isDivisibleInSize || *foldedNumFrontPadElems != 0;
+    // EXAMPLE 2:
+    //  Does not require a partial store:
+    //    vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+    //
+    // TODO: Take linearizedInfo.linearizedOffset into account. This is
+    // currently not needed/used/exercised as all our tests set offset to 0.
+    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
+    emulationRequiresPartialStores = true;
+
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 6fc974200c6f3..40961349a4a62 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -361,6 +361,75 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
 /// vector.store
 ///----------------------------------------------------------------------------------------
 
+// -----
+
+// Most basic example to demonstrate where partial stores are not needed.
+
+func.func @vector_store_i2_const_index_no_partial_store(%arg0: vector<4xi2>) {
+    %0 = memref.alloc() : memref<13xi2>
+    %c4 = arith.constant 4 : index
+    vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+    return
+}
+// CHECK-LABEL:   func.func @vector_store_i2_const_index_no_partial_store(
+// CHECK-SAME:      %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK-NOT:       memref.generic_atomic_rmw
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[IDX:.*]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK:           %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
+// CHECK:           vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[IDX]]] : memref<4xi8>, vector<1xi8>
+
+// -----
+
+// Small modification of the example above to demonstrate where partial stores
+// are needed.
+
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<4xi2>) {
+    %0 = memref.alloc() : memref<13xi2>
+    %c3 = arith.constant 3 : index
+    vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
+    return
+}
+
+// CHECK-LABEL:   func.func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME:      %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK:           %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>
+
+// First atomic RMW:
+// CHECK:           %[[IDX_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK:           %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
+// CHECK:           %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
+// CHECK:           %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK:           memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: i8):
+// CHECK:             %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
+// CHECK:             %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
+// CHECK:             %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
+// CHECK:             %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
+// CHECK:             %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
+// CHECK:             memref.atomic_yield %[[RES_1]] : i8
+// CHECK:           }
+
+// Second atomic RMW:
+// CHECK:           %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:           %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
+// CHECK:           %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+// CHECK:           %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
+// CHECK:           %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK:            memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
+// CHECK:           ^bb0(%[[VAL_20:.*]]: i8):
+// CHECK:             %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
+// CHECK:             %[[DOWNCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
+// CHECK:             %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DOWNCAST_2]] : vector<4xi1>, vector<4xi2>
+// CHECK:             %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
+// CHECK:             %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
+// CHECK:             memref.atomic_yield %[[RES_2]] : i8
+// CHECK:           }
+
+// -----
+
 func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
     %src = memref.alloc() : memref<3x3xi2>
     %c0 = arith.constant 0 : index


github-actions bot commented Mar 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space banach-space force-pushed the andrzej/narrow_type/fix branch from eddf322 to c210c2d Compare March 27, 2025 12:27
@banach-space banach-space marked this pull request as draft March 27, 2025 12:44
@banach-space banach-space force-pushed the andrzej/narrow_type/fix branch from c210c2d to ef434aa Compare March 27, 2025 13:28
@banach-space banach-space marked this pull request as ready for review March 27, 2025 14:13
@banach-space banach-space force-pushed the andrzej/narrow_type/fix branch from ef434aa to 52ebd48 Compare March 31, 2025 13:20
@@ -593,10 +593,20 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto origElements = valueToStore.getType().getNumElements();
// Note, per-element-alignment was already verified above.
bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
// Do the trailing dim for source and destination match? If yes, then the
Member

Trying to understand this. If both the source's and destination's last dims are the same, why would we want to skip this check, unless the last dim is aligned?

Contributor Author

@banach-space banach-space Apr 2, 2025

> why would we want to skip this check?

Are you referring to the calculation of `foldedNumFrontPadElems`?

If so, the reason is tied to a key assumption we're making here:

  • The vector being stored is 1D (see this code).

Under that assumption, here’s one possible case:

```
vector.store %arg0, %0[%idx] : memref<4xi2>, vector<4xi2>
```

In this case, `%idx` can only be 0. Assuming the "emulated" and "container" types are aligned, we definitely don’t need padding here - hence `foldedNumFrontPadElems = 0`.

However, in a case like:

```
vector.store %arg0, %0[%idx] : memref<8xi2>, vector<4xi2>
```

`%idx` could be 0 or 1, and if it's 1, padding is required even when the "emulated" and "container" types are aligned.

That’s the nuance I was trying to capture.
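
A short sketch of that nuance with concrete numbers. This is a rough
illustration rather than code from the pass; the i8 container type and the
1-D restriction are taken from the discussion above.

```cpp
#include <cstdint>
#include <cstdio>

// For vector<4xi2> the size divides the i8 container exactly, yet the front
// padding still depends on where the store starts within a container byte:
//   memref<4xi2>: %idx can only be 0 -> intraDataOffset = 0, no padding
//   memref<8xi2>: %idx may be 0 or 1 -> %idx == 1 gives intraDataOffset = 1,
//                 so padding (and hence RMW stores) is needed even though the
//                 emulated and container types are "aligned".
int main() {
  const int64_t elemsPerContainer = 8 / 2; // an i8 container holds four i2s
  const int64_t indices[] = {0, 1};
  for (int64_t idx : indices) {
    int64_t intraDataOffset = idx % elemsPerContainer;
    std::printf("%%idx = %lld -> intraDataOffset = %lld\n", (long long)idx,
                (long long)intraDataOffset);
  }
  return 0;
}
```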
