Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Update helpers in VectorEmulateNarrowType.cpp (nfc) #131527

Conversation

banach-space
Copy link
Contributor

Refactors the following pairs of helper hooks:

  • dynamicallyInsertSubVector + staticallyInsertSubVector
  • dynamicallyExtractSubVector + staticallyExtractSubVector

These hooks are very similar, so I have unified the variable names and
various conditions to make the actual differences clearer.

Refactors the following pairs of helper hooks:
  * `dynamicallyInsertSubVector` + `staticallyInsertSubVector`
  * `dynamicallyExtractSubVector` + `staticallyExtractSubVector`

These hooks are very similar, so I have unified the variable names and
various conditions to make the actual differences clearer.
@llvmbot
Copy link
Member

llvmbot commented Mar 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Refactors the following pairs of helper hooks:

  • dynamicallyInsertSubVector + staticallyInsertSubVector
  • dynamicallyExtractSubVector + staticallyExtractSubVector

These hooks are very similar, so I have unified the variable names and
various conditions to make the actual differences clearer.


Full diff: https://github.com/llvm/llvm-project/pull/131527.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+112-41)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index cf6efaa04ae44..456a83503ea8f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -198,85 +198,156 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return *newMask;
 }
 
-/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
-/// emitting `vector.extract_strided_slice`.
+/// Extracts 1-D subvector from a 1-D vector.
+///
+/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
+/// from `src`, starting at `offset`. The result is also a rank-1 vector:
+///
+///   vector<numElemsToExtract x !elType>
+///
+/// (`!elType` is the element type of the source vector). As `offset` is a known
+/// _static_ value, this helper hook emits `vector.extract_strided_slice`.
+///
+/// EXAMPLE:
+///     %res = vector.extract_strided_slice %src
+///       { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        Value source, int64_t frontOffset,
-                                        int64_t subvecSize) {
-  auto vectorType = cast<VectorType>(source.getType());
-  assert(vectorType.getRank() == 1 && "expected 1-D source types");
-  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
+                                        Value src, int64_t offset,
+                                        int64_t numElemsToExtract) {
+  auto vectorType = cast<VectorType>(src.getType());
+  assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector ");
+  assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
          "subvector out of bounds");
 
-  // do not need extraction if the subvector size is the same as the source
-  if (vectorType.getNumElements() == subvecSize)
-    return source;
+  // When extracting all available elements, just use the source vector as the
+  // result.
+  if (vectorType.getNumElements() == numElemsToExtract)
+    return src;
 
-  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
-  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
+  auto offsets = rewriter.getI64ArrayAttr({offset});
+  auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
   auto strides = rewriter.getI64ArrayAttr({1});
 
   auto resultVectorType =
-      VectorType::get({subvecSize}, vectorType.getElementType());
+      VectorType::get({numElemsToExtract}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
                                              offsets, sizes, strides)
       ->getResult(0);
 }
 
-/// Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
-/// at `offset`. it is a wrapper function for emitting
+/// Inserts 1-D subvector into a 1-D vector.
+///
+/// Inserts the input rank-1 source vector into the destination vector starting
+/// at `offset`. As `offset` is a known _static_ value, this helper hook emits
 /// `vector.insert_strided_slice`.
+///
+/// EXAMPLE:
+///   %res = vector.insert_strided_slice %src, %dest
+///     {offsets = [%offset], strides [1]}
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
-  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
-  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
-         "expected source and dest to be vector type");
+  auto srcVecTy = cast<VectorType>(src.getType());
+  auto destVecTy = cast<VectorType>(dest.getType());
+  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
+         "expected source and dest to be rank-1 vector types");
+
+  // If overwritting the destination vector, just return the source.
+  if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
+    return src;
+
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
-  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
+  return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
                                                        dest, offsets, strides);
 }
 
-/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
-/// and size `numElementsToExtract`, and inserts into the `dest` vector. This
-/// function emits multiple `vector.extract` and `vector.insert` ops, so only
-/// use it when `offset` cannot be folded into a constant value.
+/// Extracts 1-D subvector from a 1-D vector.
+///
+/// Given the input rank-1 source vector, extracts `numElemsToExtact` elements
+/// from `src`, starting at `offset`. The result is also a rank-1 vector:
+///
+///   vector<numElemsToExtact x !elType>
+///
+/// (`!elType` is the element type of the source vector). As `offset` is assumed
+/// to be a _dynamic_ SSA value, this helper method generates a sequence of
+/// `vector.extract` + `vector.insert` pairs.
+///
+/// EXAMPLE:
+///     %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
+///     %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
+///     %c1 = arith.constant 1 : index
+///     %idx2 = arith.addi %offset, %c1 : index
+///     %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
+///     %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
+///     (...)
 static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
-                                         Value source, Value dest,
+                                         Value src, Value dest,
                                          OpFoldResult offset,
-                                         int64_t numElementsToExtract) {
-  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
-  for (int i = 0; i < numElementsToExtract; ++i) {
+                                         int64_t numElemsToExtract) {
+  auto srcVecTy = cast<VectorType>(src.getType());
+  assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector ");
+  // NOTE: We are unable to take the offset into account in the following
+  // assert, hence its still possible that the subvector is out-of-bounds even
+  // if the condition is true.
+  assert(numElemsToExtract <= srcVecTy.getNumElements() &&
+         "subvector out of bounds");
+
+  // When extracting all available elements, just use the source vector as the
+  // result.
+  if (srcVecTy.getNumElements() == numElemsToExtract)
+    return src;
+
+  for (int i = 0; i < numElemsToExtract; ++i) {
     Value extractLoc =
         (i == 0) ? offset.dyn_cast<Value>()
                  : rewriter.create<arith::AddIOp>(
                        loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
                        rewriter.create<arith::ConstantIndexOp>(loc, i));
-    auto extractOp =
-        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
+    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
     dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
   }
   return dest;
 }
 
-/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
+/// Inserts 1-D subvector into a 1-D vector.
+///
+/// Inserts the input rank-1 source vector into the destination vector starting
+/// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
+/// uses a sequence of `vector.extract` + `vector.insert` pairs.
+///
+/// EXAMPLE:
+///     %v1 = vector.extract %src[0] : i2 from vector<8xi2>
+///     %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
+///     %c1 = arith.constant 1 : index
+///     %idx2 = arith.addi %offset, %c1 : index
+///     %v2 = vector.extract %src[1] : i2 from vector<8xi2>
+///     %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
+///     (...)
 static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
-                                        Value source, Value dest,
-                                        OpFoldResult destOffsetVar,
-                                        size_t length) {
-  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
-  assert(length > 0 && "length must be greater than 0");
-  Value destOffsetVal =
-      getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
-  for (size_t i = 0; i < length; ++i) {
+                                        Value src, Value dest,
+                                        OpFoldResult offset,
+                                        int64_t numElemsToInsert) {
+  auto srcVecTy = cast<VectorType>(src.getType());
+  auto destVecTy = cast<VectorType>(dest.getType());
+  assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
+         "expected source and dest to be rank-1 vector types");
+  assert(numElemsToInsert > 0 &&
+         "the number of elements to insert must be greater than 0");
+  // NOTE: We are unable to take the offset into account in the following
+  // assert, hence its still possible that the subvector is out-of-bounds even
+  // if the condition is true.
+  assert(numElemsToInsert <= destVecTy.getNumElements() &&
+         "subvector out of bounds");
+
+  Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
+  for (int64_t i = 0; i < numElemsToInsert; ++i) {
     auto insertLoc = i == 0
                          ? destOffsetVal
                          : rewriter.create<arith::AddIOp>(
                                loc, rewriter.getIndexType(), destOffsetVal,
                                rewriter.create<arith::ConstantIndexOp>(loc, i));
-    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
+    auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
     dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
   }
   return dest;

@banach-space banach-space requested a review from lialan March 17, 2025 11:41
/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
/// from `src`, starting at `offset`. The result is also a rank-1 vector:
///
/// vector<numElemsToExtract x !elType>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer elemType?

@banach-space banach-space merged commit 9768077 into llvm:main Mar 25, 2025
11 checks passed
@banach-space banach-space deleted the andrzej/narrow_type/unify_extract_insert_hooks branch March 25, 2025 13:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants