Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Update helpers in VectorEmulateNarrowType.cpp (nfc) #131527

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 112 additions & 41 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,85 +198,156 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return *newMask;
}

/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
/// Extracts 1-D subvector from a 1-D vector.
///
/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
/// from `src`, starting at `offset`. The result is also a rank-1 vector:
///
/// vector<numElemsToExtract x !elemType>
///
/// (`!elType` is the element type of the source vector). As `offset` is a known
/// _static_ value, this helper hook emits `vector.extract_strided_slice`.
///
/// EXAMPLE:
/// %res = vector.extract_strided_slice %src
/// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
Value source, int64_t frontOffset,
int64_t subvecSize) {
auto vectorType = cast<VectorType>(source.getType());
assert(vectorType.getRank() == 1 && "expected 1-D source types");
assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
Value src, int64_t offset,
int64_t numElemsToExtract) {
auto vectorType = cast<VectorType>(src.getType());
assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector ");
assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
"subvector out of bounds");

// do not need extraction if the subvector size is the same as the source
if (vectorType.getNumElements() == subvecSize)
return source;
// When extracting all available elements, just use the source vector as the
// result.
if (vectorType.getNumElements() == numElemsToExtract)
return src;

auto offsets = rewriter.getI64ArrayAttr({frontOffset});
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
auto offsets = rewriter.getI64ArrayAttr({offset});
auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
auto strides = rewriter.getI64ArrayAttr({1});

auto resultVectorType =
VectorType::get({subvecSize}, vectorType.getElementType());
VectorType::get({numElemsToExtract}, vectorType.getElementType());
return rewriter
.create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
.create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
offsets, sizes, strides)
->getResult(0);
}

/// Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
/// at `offset`. it is a wrapper function for emitting
/// Inserts 1-D subvector into a 1-D vector.
///
/// Inserts the input rank-1 source vector into the destination vector starting
/// at `offset`. As `offset` is a known _static_ value, this helper hook emits
/// `vector.insert_strided_slice`.
///
/// EXAMPLE:
/// %res = vector.insert_strided_slice %src, %dest
/// {offsets = [%offset], strides [1]}
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
Value src, Value dest, int64_t offset) {
[[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
[[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
"expected source and dest to be vector type");
auto srcVecTy = cast<VectorType>(src.getType());
auto destVecTy = cast<VectorType>(dest.getType());
assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
"expected source and dest to be rank-1 vector types");

// If overwritting the destination vector, just return the source.
if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
return src;

auto offsets = rewriter.getI64ArrayAttr({offset});
auto strides = rewriter.getI64ArrayAttr({1});
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
dest, offsets, strides);
}

/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
/// and size `numElementsToExtract`, and inserts into the `dest` vector. This
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
/// use it when `offset` cannot be folded into a constant value.
/// Extracts 1-D subvector from a 1-D vector.
///
/// Given the input rank-1 source vector, extracts `numElemsToExtact` elements
/// from `src`, starting at `offset`. The result is also a rank-1 vector:
///
/// vector<numElemsToExtact x !elType>
///
/// (`!elType` is the element type of the source vector). As `offset` is assumed
/// to be a _dynamic_ SSA value, this helper method generates a sequence of
/// `vector.extract` + `vector.insert` pairs.
///
/// EXAMPLE:
/// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
/// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
/// %c1 = arith.constant 1 : index
/// %idx2 = arith.addi %offset, %c1 : index
/// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
/// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
/// (...)
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
Value source, Value dest,
Value src, Value dest,
OpFoldResult offset,
int64_t numElementsToExtract) {
assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
for (int i = 0; i < numElementsToExtract; ++i) {
int64_t numElemsToExtract) {
auto srcVecTy = cast<VectorType>(src.getType());
assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector ");
// NOTE: We are unable to take the offset into account in the following
// assert, hence its still possible that the subvector is out-of-bounds even
// if the condition is true.
assert(numElemsToExtract <= srcVecTy.getNumElements() &&
"subvector out of bounds");

// When extracting all available elements, just use the source vector as the
// result.
if (srcVecTy.getNumElements() == numElemsToExtract)
return src;

for (int i = 0; i < numElemsToExtract; ++i) {
Value extractLoc =
(i == 0) ? offset.dyn_cast<Value>()
: rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
rewriter.create<arith::ConstantIndexOp>(loc, i));
auto extractOp =
rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
}
return dest;
}

/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
/// Inserts 1-D subvector into a 1-D vector.
///
/// Inserts the input rank-1 source vector into the destination vector starting
/// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
/// uses a sequence of `vector.extract` + `vector.insert` pairs.
///
/// EXAMPLE:
/// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
/// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
/// %c1 = arith.constant 1 : index
/// %idx2 = arith.addi %offset, %c1 : index
/// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
/// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
/// (...)
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
Value source, Value dest,
OpFoldResult destOffsetVar,
size_t length) {
assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
assert(length > 0 && "length must be greater than 0");
Value destOffsetVal =
getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
for (size_t i = 0; i < length; ++i) {
Value src, Value dest,
OpFoldResult offset,
int64_t numElemsToInsert) {
auto srcVecTy = cast<VectorType>(src.getType());
auto destVecTy = cast<VectorType>(dest.getType());
assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
"expected source and dest to be rank-1 vector types");
assert(numElemsToInsert > 0 &&
"the number of elements to insert must be greater than 0");
// NOTE: We are unable to take the offset into account in the following
// assert, hence its still possible that the subvector is out-of-bounds even
// if the condition is true.
assert(numElemsToInsert <= destVecTy.getNumElements() &&
"subvector out of bounds");

Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
for (int64_t i = 0; i < numElemsToInsert; ++i) {
auto insertLoc = i == 0
? destOffsetVal
: rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), destOffsetVal,
rewriter.create<arith::ConstantIndexOp>(loc, i));
auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
}
return dest;
Expand Down