Skip to content

Commit

Permalink
[mlir] ArithToLLVM: fix memref bitcast lowering
Browse files Browse the repository at this point in the history
arith.bitcast is allowed on memrefs and such code can actually be generated by IREE `ConvertBf16ArithToF32Pass`.
`LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate bitcast between structs which is illegal.
With the opaque pointers this is a no-op operation for memref so we can just add type check in `LLVM::detail::vectorOneToOneRewrite` and add a separate pattern which removes op if converted types are the same.
  • Loading branch information
Hardcode84 committed Jan 31, 2025
1 parent bce2cc1 commit 2fc606a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
20 changes: 20 additions & 0 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@ struct ConstrainedVectorConvertToLLVMPattern
}
};

/// No-op bitcast.
struct IdentityBitcastLowering final
: public OpConversionPattern<arith::BitcastOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Value src = adaptor.getIn();
if (src.getType() != getTypeConverter()->convertType(op.getType()))
return rewriter.notifyMatchFailure(op, "Types are different");

rewriter.replaceOp(op, src);
return success();
}
};

//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -524,6 +541,9 @@ void mlir::arith::registerConvertArithToLLVMInterface(

void mlir::arith::populateArithToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {

patterns.add<IdentityBitcastLowering>(converter, patterns.getContext());

// clang-format off
patterns.add<
AddFOpLowering,
Expand Down
7 changes: 6 additions & 1 deletion mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
return success();
}

static bool isVectorCompatibleType(Type type) {
return isa<LLVM::LLVMArrayType, VectorType, IntegerType, FloatType>(type) &&
LLVM::isCompatibleType(type);
}

LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
Expand All @@ -111,7 +116,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
assert(!operands.empty());

// Cannot convert ops if their operands are not of LLVM type.
if (!llvm::all_of(operands.getTypes(), isCompatibleType))
if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType))
return failure();

auto llvmNDVectorTy = operands[0].getType();
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -727,3 +727,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}

// -----

// CHECK-LABEL: func @memref_bitcast
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>)
// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
// CHECK: return %[[V2]]
func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
%2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
func.return %2 : memref<?xbf16>
}

0 comments on commit 2fc606a

Please sign in to comment.