diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 754ed89814293..5c1afe8034c73 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern } }; +/// No-op bitcast. Propagate type input arg if converted source and dest types +/// are the same. +struct IdentityBitcastLowering final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + Value src = adaptor.getIn(); + Type resultType = getTypeConverter()->convertType(op.getType()); + if (src.getType() != resultType) + return rewriter.notifyMatchFailure(op, "Types are different"); + + rewriter.replaceOp(op, src); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// @@ -524,6 +543,9 @@ void mlir::arith::registerConvertArithToLLVMInterface( void mlir::arith::populateArithToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + + patterns.add(converter, patterns.getContext()); + // clang-format off patterns.add< AddFOpLowering, diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index bf3f31729c3da..fe4781138fa29 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -103,6 +103,14 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( return success(); } +static bool isVectorCompatibleType(Type type) { + // Limit `vectorOneToOneRewrite` to scalar and vector types (and to + // `LLVM::LLVMArrayType` which have a special handling). + return isa(type) && + LLVM::isCompatibleType(type); +} + LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, ArrayRef targetAttrs, @@ -111,7 +119,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. - if (!llvm::all_of(operands.getTypes(), isCompatibleType)) + if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType)) return failure(); auto llvmNDVectorTy = operands[0].getType(); diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 1dabacfd8a47c..9a6c4bca88f3b 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -727,3 +727,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) { %3 = arith.shli %arg0, %arg1 overflow : i64 return } + +// ----- + +// CHECK-LABEL: func @memref_bitcast +// CHECK-SAME: (%[[ARG:.*]]: memref) +// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref +// CHECK: return %[[V2]] +func.func @memref_bitcast(%1: memref) -> memref { + %2 = arith.bitcast %1 : memref to memref + func.return %2 : memref +}