-
Notifications
You must be signed in to change notification settings - Fork 13.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Allow !llvm.ptr
as vector element type
#125690
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, Matthias! I don't have major concerns at this point.
@@ -43,6 +43,10 @@ template <typename ConcreteType> | |||
class ValueSemantics | |||
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {}; | |||
|
|||
/// Type trait indicating that the type is a pointer-like type. | |||
template <typename ConcreteType> | |||
class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if this should be an interface... we may have common methods to retrieve the address space, at least?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As long as there are no interface methods that we would like to add (now or in the future), I would leave it as a trait.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure there's demand upstream for an interface, but querying the memory space and perhaps whether the pointer is opaque could be useful for downstream. Maybe CIR and flang folks might have a use.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that changing this to an interface is not really intrusive, I vote in favor of adding this once a use case pops up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure there's demand upstream for an interface, but querying the memory space and perhaps whether the pointer is opaque could be useful for downstream. Maybe CIR and flang folks might have a use.
Yes, that's what I was thinking but I'm also ok with adding this on a case by case basis.
5b7422a
to
6f6cd95
Compare
!llvm.ptr
as vector element type
6f6cd95
to
3320422
Compare
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit extends the MLIR vector type to support pointer-like types such as This commit also disallows RFC: https://discourse.llvm.org/t/rfc-allow-pointers-as-element-type-of-vector/85360 Patch is 45.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125690.diff 25 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 8b380751c2f9d..bca0feb45aab2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 2ecbf8f50c50c..3386003cb61fb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -259,7 +260,8 @@ def LLVMStructType : LLVMType<"LLVMStruct", "struct", [
def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"getIndexBitwidth", "areCompatible", "verifyEntries",
- "getPreferredAlignment"]>]> {
+ "getPreferredAlignment"]>,
+ PointerLike]> {
let summary = "LLVM pointer type";
let description = [{
The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 857e68cec8c76..c73b6ab3a46b5 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -12,6 +12,7 @@
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
@@ -39,7 +40,8 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
MemRefElementTypeInterface,
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"areCompatible", "getIndexBitwidth", "verifyEntries",
- "getPreferredAlignment"]>
+ "getPreferredAlignment"]>,
+ PointerLike
]> {
let summary = "pointer type";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
index 4fe1b5a1aa423..564d02218eb8e 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_PTR_IR_PTRTYPES_H
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..7f20878d02f67 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,6 +43,10 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
+/// Type trait indicating that the type is a pointer-like type.
+template <typename ConcreteType>
+class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
+
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..db84063b42a0a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -40,6 +40,11 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
let cppNamespace = "::mlir";
}
+/// Type trait indicating that the type is a pointer-like type.
+def PointerLike : NativeTypeTrait<"PointerLike"> {
+ let cppNamespace = "::mlir";
+}
+
//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//
@@ -1249,7 +1254,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat, AnyPointerLike]> {
let cppFunctionName = "isValidVectorTypeElementType";
}
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..1fd48d7ab7ee3 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -301,6 +301,8 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
"::mlir::IndexType">,
BuildableType<"$_builder.getIndexType()">;
+def AnyPointerLike : Type<CPred<"$_self.hasTrait<::mlir::PointerLike>()">, "pointer-like", "::mlir::Type">;
+
// Any signless integer type or index type.
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
"signless integer or index">;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index d1ccb487d2265..38171267e82ff 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -140,9 +140,13 @@ static bool isSupportedTypeForConversion(Type type) {
if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
return false;
- // Scalable types are not supported.
- if (auto vectorType = dyn_cast<VectorType>(type))
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
+ // Vectors of pointers cannot be casted.
+ if (isa<LLVM::LLVMPointerType>(vectorType.getElementType()))
+ return false;
+ // Scalable types are not supported.
return !vectorType.isScalable();
+ }
return true;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 8f39ede721c92..403756765268e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -690,7 +690,7 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
}
bool LLVMFixedVectorType::isValidElementType(Type type) {
- return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
+ return llvm::isa<LLVMPPCFP128Type>(type);
}
LogicalResult
@@ -890,7 +890,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
return intType.isSignless();
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
- Float80Type, Float128Type>(elementType);
+ Float80Type, Float128Type, LLVMPointerType>(elementType);
}
return false;
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..800e158a94ed2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2002,8 +2002,8 @@ func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1
}
// CHECK-LABEL: func @gather
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: return %[[G]] : vector<3xf32>
// -----
@@ -2015,8 +2015,8 @@ func.func @gather_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
}
// CHECK-LABEL: func @gather_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: return %[[G]] : vector<[3]xf32>
// -----
@@ -2028,8 +2028,8 @@ func.func @gather_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %
}
// CHECK-LABEL: func @gather_global_memory
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> vector<3x!llvm.ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: return %[[G]] : vector<3xf32>
// -----
@@ -2041,8 +2041,8 @@ func.func @gather_global_memory_scalable(%arg0: memref<?xf32, 1>, %arg1: vector<
}
// CHECK-LABEL: func @gather_global_memory_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr<1>>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> vector<[3]x!llvm.ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: return %[[G]] : vector<[3]xf32>
// -----
@@ -2055,8 +2055,8 @@ func.func @gather_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: v
}
// CHECK-LABEL: func @gather_index
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<3xi64> to vector<3xindex>
// -----
@@ -2068,8 +2068,8 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
}
// CHECK-LABEL: func @gather_index_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<[3]xi64> to vector<[3]xindex>
// -----
@@ -2085,14 +2085,14 @@ func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
// -----
@@ -2108,14 +2108,14 @@ func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
// -----
@@ -2129,8 +2129,8 @@ func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2
// CHECK-LABEL: func @gather_1d_from_2d
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<4x!llvm.ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
// CHECK: return %[[G]] : vector<4xf32>
// -----
@@ -2143,8 +2143,8 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
// CHECK-LABEL: func @gather_1d_from_2d_scalable
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 4 x ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[4]x!llvm.ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
// CHECK: return %[[G]] : vector<[4]xf32>
// -----
@@ -2160,8 +2160,8 @@ func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi
}
// CHECK-LABEL: func @scatter
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
// -----
@@ -2172,8 +2172,8 @@ func.func @scatter_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
}
// CHECK-LABEL: func @scatter_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
// -----
@@ -2184,8 +2184,8 @@ func.func @scatter_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2:
}
// CHECK-LABEL: func @scatter_index
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into !llvm.vec<3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into vector<3x!llvm.ptr>
// -----
@@ -2196,8 +2196,8 @@ func.func @scatter_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xinde
}
// CHECK-LABEL: func @scatter_index_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+// CHECK: %[...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This nicely simplifies the handling of vector in LLVM dialect, thus LGTM!
I suggest to wait for a bit such that others can also weight in their opinion on this change.
@@ -43,6 +43,10 @@ template <typename ConcreteType> | |||
class ValueSemantics | |||
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {}; | |||
|
|||
/// Type trait indicating that the type is a pointer-like type. | |||
template <typename ConcreteType> | |||
class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that changing this to an interface is not really intrusive, I vote in favor of adding this once a use case pops up.
605d2f9
to
4913d6b
Compare
4913d6b
to
52fac3b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be great if we could remove !llvm.vector
(in a separate PR). Do you know if there is anything missing to do so?
@@ -43,6 +43,10 @@ template <typename ConcreteType> | |||
class ValueSemantics | |||
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {}; | |||
|
|||
/// Type trait indicating that the type is a pointer-like type. | |||
template <typename ConcreteType> | |||
class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure there's demand upstream for an interface, but querying the memory space and perhaps whether the pointer is opaque could be useful for downstream. Maybe CIR and flang folks might have a use.
Yes, that's what I was thinking but I'm also ok with adding this on a case by case basis.
The only allowed element type for |
@@ -40,6 +40,11 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> { | |||
let cppNamespace = "::mlir"; | |||
} | |||
|
|||
/// Type trait indicating that the type is a pointer-like type. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would deserve a bit more documentation: what's the implication of marking a type with "pointer-like"?
Why is memref not a a PointerLike type for example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: we wouldn't have to answer this kind of difficult question if we were making Builtin_VectorTypeElementType
a TypeInferface instead, we could implement it on the in-tree "pointer-like" types we select.
(Please take this comment thread as a current blocker)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if we rename it to BarePointerLike
(and/or mention this in a comment)? (And when we agreed on the exact mechanics of Builtin_VectorTypeElementType
we can delete BarePointerLike
again, unless we find it useful for some other reason.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that's addressing my underlying concern: anything that isn't named "vectorElementTypeInterface" or something like this should be a separate PR that we can reason about on its own merits.
The fundamental problem here is that there is nothing special about pointers for vector element types from a builtin type point of view, I don't quite see a workaround to this right now TBH.
This commit extends the MLIR vector type to support pointer-like types such as
!llvm.ptr
and!ptr.ptr
, as indicated by the newly addedPointerLike
type trait. This makes the LLVM dialect closer to LLVM IR. LLVM IR already supports pointers as vector element type.This commit also disallows
!llvm.ptr
as an element type of!llvm.vec
. This type exists due to limitations of the MLIR vector type.RFC: https://discourse.llvm.org/t/rfc-allow-pointers-as-element-type-of-vector/85360