Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] ArithToLLVM: fix memref bitcast lowering #125148

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Hardcode84
Copy link
Contributor

arith.bitcast is allowed on memrefs and such code can actually be generated by IREE ConvertBf16ArithToF32Pass. LLVM::detail::vectorOneToOneRewrite doesn't properly check its types and will generate bitcast between structs which is illegal.

With the opaque pointers this is a no-op operation for memref so we can just add type check in LLVM::detail::vectorOneToOneRewrite and add a separate pattern which removes op if converted types are the same.

arith.bitcast is allowed on memrefs and such code can actually be generated by IREE `ConvertBf16ArithToF32Pass`.
`LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate bitcast between structs which is illegal.
With the opaque pointers this is a no-op operation for memref so we can just add type check in `LLVM::detail::vectorOneToOneRewrite` and add a separate pattern which removes op if converted types are the same.
@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Ivan Butygin (Hardcode84)

Changes

arith.bitcast is allowed on memrefs and such code can actually be generated by IREE ConvertBf16ArithToF32Pass. LLVM::detail::vectorOneToOneRewrite doesn't properly check its types and will generate bitcast between structs which is illegal.

With the opaque pointers this is a no-op operation for memref so we can just add type check in LLVM::detail::vectorOneToOneRewrite and add a separate pattern which removes op if converted types are the same.


Full diff: https://github.com/llvm/llvm-project/pull/125148.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+20)
  • (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (+6-1)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+12)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 754ed898142936..b726faa92a03a0 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,6 +54,23 @@ struct ConstrainedVectorConvertToLLVMPattern
   }
 };
 
+/// No-op bitcast.
+struct IdentityBitcastLowering final
+    : public OpConversionPattern<arith::BitcastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    Value src = adaptor.getIn();
+    if (src.getType() != getTypeConverter()->convertType(op.getType()))
+      return rewriter.notifyMatchFailure(op, "Types are different");
+
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -524,6 +541,9 @@ void mlir::arith::registerConvertArithToLLVMInterface(
 
 void mlir::arith::populateArithToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+
+  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext());
+
   // clang-format off
   patterns.add<
     AddFOpLowering,
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 626135c10a3e96..c9d3b57b0d596e 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,6 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
   return success();
 }
 
+static bool isVectorCompatibleType(Type type) {
+  return isa<LLVM::LLVMArrayType, VectorType, IntegerType, FloatType>(type) &&
+         LLVM::isCompatibleType(type);
+}
+
 LogicalResult LLVM::detail::vectorOneToOneRewrite(
     Operation *op, StringRef targetOp, ValueRange operands,
     ArrayRef<NamedAttribute> targetAttrs,
@@ -111,7 +116,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   assert(!operands.empty());
 
   // Cannot convert ops if their operands are not of LLVM type.
-  if (!llvm::all_of(operands.getTypes(), isCompatibleType))
+  if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType))
     return failure();
 
   auto llvmNDVectorTy = operands[0].getType();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 1dabacfd8a47cc..9a6c4bca88f3bf 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -727,3 +727,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<?xi16>)
+//       CHECK:   %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//       CHECK:   %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
+//       CHECK:   return %[[V2]]
+func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
+  %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
+  func.return %2 : memref<?xbf16>
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants