Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Allow !llvm.ptr as vector element type #125690

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Feb 4, 2025

This commit extends the MLIR vector type to support pointer-like types such as !llvm.ptr and !ptr.ptr, as indicated by the newly added PointerLike type trait. This makes the LLVM dialect closer to LLVM IR. LLVM IR already supports pointers as vector element type.

This commit also disallows !llvm.ptr as an element type of !llvm.vec. This type exists due to limitations of the MLIR vector type.

RFC: https://discourse.llvm.org/t/rfc-allow-pointers-as-element-type-of-vector/85360

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Matthias! I don't have major concerns at this point.

@@ -43,6 +43,10 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};

/// Type trait indicating that the type is a pointer-like type.
template <typename ConcreteType>
class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this should be an interface... we may have common methods to retrieve the address space, at least?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as there are no interface methods that we would like to add (now or in the future), I would leave it as a trait.

Copy link
Contributor

@fabianmcg fabianmcg Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure there's demand upstream for an interface, but querying the memory space and perhaps whether the pointer is opaque could be useful for downstream. Maybe CIR and flang folks might have a use.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that changing this to an interface is not really intrusive, I vote in favor of adding this once a use case pops up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure there's demand upstream for an interface, but querying the memory space and perhaps whether the pointer is opaque could be useful for downstream. Maybe CIR and flang folks might have a use.

Yes, that's what I was thinking but I'm also ok with adding this on a case by case basis.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_type_ptr branch from 5b7422a to 6f6cd95 Compare March 20, 2025 12:59
@matthias-springer matthias-springer changed the title [mlir][IR] Experiment: Allow ptr as vector element type [mlir][IR] Allow ptr as vector element type Mar 20, 2025
@matthias-springer matthias-springer changed the title [mlir][IR] Allow ptr as vector element type [mlir][IR] Allow !llvm.ptr as vector element type Mar 20, 2025
@matthias-springer matthias-springer changed the title [mlir][IR] Allow !llvm.ptr as vector element type [mlir][IR] Allow !llvm.ptr as vector element type Mar 20, 2025
@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_type_ptr branch from 6f6cd95 to 3320422 Compare March 20, 2025 13:33
@matthias-springer matthias-springer marked this pull request as ready for review March 20, 2025 13:33
@llvmbot
Copy link
Member

llvmbot commented Mar 20, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit extends the MLIR vector type to support pointer-like types such as !llvm.ptr and !ptr.ptr, as indicated by the newly added PointerLike type trait. This makes the LLVM dialect closer to LLVM IR. LLVM IR already supports pointers as vector element type.

This commit also disallows !llvm.ptr as an element type of !llvm.vec. This type exists due to limitations of the MLIR vector type.

RFC: https://discourse.llvm.org/t/rfc-allow-pointers-as-element-type-of-vector/85360


Patch is 45.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125690.diff

25 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (+1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td (+3-1)
  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td (+3-1)
  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+4)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+6-1)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+6-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+2-2)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+36-36)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+4-4)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+5-5)
  • (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+1-1)
  • (modified) mlir/test/Dialect/LLVMIR/opaque-ptr.mlir (+3-3)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+7-7)
  • (modified) mlir/test/Dialect/LLVMIR/types.mlir (+2-2)
  • (modified) mlir/test/IR/test-verifiers-type.mlir (+15)
  • (modified) mlir/test/Target/LLVMIR/Import/constant.ll (+4-4)
  • (modified) mlir/test/Target/LLVMIR/Import/incorrect-scalable-vector-check.ll (+1-1)
  • (modified) mlir/test/Target/LLVMIR/Import/intrinsic.ll (+5-5)
  • (modified) mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir (+10-10)
  • (modified) mlir/test/Target/LLVMIR/llvmir-invalid.mlir (+8-8)
  • (modified) mlir/test/Target/LLVMIR/llvmir-types.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3-3)
  • (modified) mlir/test/Target/LLVMIR/opaque-ptr.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 8b380751c2f9d..bca0feb45aab2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
 #define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 2ecbf8f50c50c..3386003cb61fb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 
@@ -259,7 +260,8 @@ def LLVMStructType : LLVMType<"LLVMStruct", "struct", [
 def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
       "getIndexBitwidth", "areCompatible", "verifyEntries",
-      "getPreferredAlignment"]>]> {
+      "getPreferredAlignment"]>,
+    PointerLike]> {
   let summary = "LLVM pointer type";
   let description = [{
     The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 857e68cec8c76..c73b6ab3a46b5 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -12,6 +12,7 @@
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/BuiltinTypes.td"
 include "mlir/IR/OpBase.td"
 
 //===----------------------------------------------------------------------===//
@@ -39,7 +40,8 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
     MemRefElementTypeInterface,
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
       "areCompatible", "getIndexBitwidth", "verifyEntries",
-      "getPreferredAlignment"]>
+      "getPreferredAlignment"]>,
+    PointerLike
   ]> {
   let summary = "pointer type";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
index 4fe1b5a1aa423..564d02218eb8e 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_PTR_IR_PTRTYPES_H
 
 #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..7f20878d02f67 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -43,6 +43,10 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
+/// Type trait indicating that the type is a pointer-like type.
+template <typename ConcreteType>
+class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
+
 //===----------------------------------------------------------------------===//
 // TensorType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..db84063b42a0a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -40,6 +40,11 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
   let cppNamespace = "::mlir";
 }
 
+/// Type trait indicating that the type is a pointer-like type.
+def PointerLike : NativeTypeTrait<"PointerLike"> {
+  let cppNamespace = "::mlir";
+}
+
 //===----------------------------------------------------------------------===//
 // ComplexType
 //===----------------------------------------------------------------------===//
@@ -1249,7 +1254,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat, AnyPointerLike]> {
   let cppFunctionName = "isValidVectorTypeElementType";
 }
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..1fd48d7ab7ee3 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -301,6 +301,8 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
                  "::mlir::IndexType">,
             BuildableType<"$_builder.getIndexType()">;
 
+def AnyPointerLike : Type<CPred<"$_self.hasTrait<::mlir::PointerLike>()">, "pointer-like", "::mlir::Type">;
+
 // Any signless integer type or index type.
 def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
                                      "signless integer or index">;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index d1ccb487d2265..38171267e82ff 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -140,9 +140,13 @@ static bool isSupportedTypeForConversion(Type type) {
   if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
     return false;
 
-  // Scalable types are not supported.
-  if (auto vectorType = dyn_cast<VectorType>(type))
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
+    // Vectors of pointers cannot be casted.
+    if (isa<LLVM::LLVMPointerType>(vectorType.getElementType()))
+      return false;
+    // Scalable types are not supported.
     return !vectorType.isScalable();
+  }
   return true;
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 8f39ede721c92..403756765268e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -690,7 +690,7 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 }
 
 bool LLVMFixedVectorType::isValidElementType(Type type) {
-  return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
+  return llvm::isa<LLVMPPCFP128Type>(type);
 }
 
 LogicalResult
@@ -890,7 +890,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
     if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
       return intType.isSignless();
     return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
-                     Float80Type, Float128Type>(elementType);
+                     Float80Type, Float128Type, LLVMPointerType>(elementType);
   }
   return false;
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..800e158a94ed2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2002,8 +2002,8 @@ func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1
 }
 
 // CHECK-LABEL: func @gather
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: return %[[G]] : vector<3xf32>
 
 // -----
@@ -2015,8 +2015,8 @@ func.func @gather_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
 }
 
 // CHECK-LABEL: func @gather_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
 // CHECK: return %[[G]] : vector<[3]xf32>
 
 // -----
@@ -2028,8 +2028,8 @@ func.func @gather_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %
 }
 
 // CHECK-LABEL: func @gather_global_memory
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> vector<3x!llvm.ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: return %[[G]] : vector<3xf32>
 
 // -----
@@ -2041,8 +2041,8 @@ func.func @gather_global_memory_scalable(%arg0: memref<?xf32, 1>, %arg1: vector<
 }
 
 // CHECK-LABEL: func @gather_global_memory_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr<1>>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> vector<[3]x!llvm.ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
 // CHECK: return %[[G]] : vector<[3]xf32>
 
 // -----
@@ -2055,8 +2055,8 @@ func.func @gather_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: v
 }
 
 // CHECK-LABEL: func @gather_index
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
 // CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<3xi64> to vector<3xindex>
 
 // -----
@@ -2068,8 +2068,8 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
 }
 
 // CHECK-LABEL: func @gather_index_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
 // CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<[3]xi64> to vector<[3]xindex>
 
 // -----
@@ -2085,14 +2085,14 @@ func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
 // CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
 // CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
 // CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
 // CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
 // CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
 // CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
 
 // -----
@@ -2108,14 +2108,14 @@ func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
 // CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
 // CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
 // CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
 // CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
 // CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
 // CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
 // CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
 // CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
 
 // -----
@@ -2129,8 +2129,8 @@ func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2
 
 // CHECK-LABEL: func @gather_1d_from_2d
 // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<4x!llvm.ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
 // CHECK: return %[[G]] : vector<4xf32>
 
 // -----
@@ -2143,8 +2143,8 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
 
 // CHECK-LABEL: func @gather_1d_from_2d_scalable
 // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 4 x ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[4]x!llvm.ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
 // CHECK: return %[[G]] : vector<[4]xf32>
 
 // -----
@@ -2160,8 +2160,8 @@ func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi
 }
 
 // CHECK-LABEL: func @scatter
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
 
 // -----
 
@@ -2172,8 +2172,8 @@ func.func @scatter_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
 }
 
 // CHECK-LABEL: func @scatter_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
 
 // -----
 
@@ -2184,8 +2184,8 @@ func.func @scatter_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2:
 }
 
 // CHECK-LABEL: func @scatter_index
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into !llvm.vec<3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into vector<3x!llvm.ptr>
 
 // -----
 
@@ -2196,8 +2196,8 @@ func.func @scatter_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xinde
 }
 
 // CHECK-LABEL: func @scatter_index_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+// CHECK: %[...
[truncated]

@kuhar kuhar self-requested a review March 20, 2025 15:25
Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This nicely simplifies the handling of vector in LLVM dialect, thus LGTM!
I suggest to wait for a bit such that others can also weight in their opinion on this change.

@@ -43,6 +43,10 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};

/// Type trait indicating that the type is a pointer-like type.
template <typename ConcreteType>
class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that changing this to an interface is not really intrusive, I vote in favor of adding this once a use case pops up.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_type_ptr branch 2 times, most recently from 605d2f9 to 4913d6b Compare March 27, 2025 10:35
@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_type_ptr branch from 4913d6b to 52fac3b Compare March 27, 2025 10:46
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great if we could remove !llvm.vector (in a separate PR). Do you know if there is anything missing to do so?

@@ -43,6 +43,10 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};

/// Type trait indicating that the type is a pointer-like type.
template <typename ConcreteType>
class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure there's demand upstream for an interface, but querying the memory space and perhaps whether the pointer is opaque could be useful for downstream. Maybe CIR and flang folks might have a use.

Yes, that's what I was thinking but I'm also ok with adding this on a case by case basis.

@matthias-springer
Copy link
Member Author

It would be great if we could remove !llvm.vector (in a separate PR). Do you know if there is anything missing to do so?

The only allowed element type for !llvm.vec is now LLVMPPCFP128Type. It corresponds to APFloat::PPCDoubleDouble. I should be able to clean this up, now that FloatType is a type interface.

@@ -40,6 +40,11 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
let cppNamespace = "::mlir";
}

/// Type trait indicating that the type is a pointer-like type.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would deserve a bit more documentation: what's the implication of marking a type with "pointer-like"?
Why is memref not a a PointerLike type for example?

Copy link
Collaborator

@joker-eph joker-eph Mar 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: we wouldn't have to answer this kind of difficult question if we were making Builtin_VectorTypeElementType a TypeInferface instead, we could implement it on the in-tree "pointer-like" types we select.

(Please take this comment thread as a current blocker)

Copy link
Member Author

@matthias-springer matthias-springer Mar 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we rename it to BarePointerLike (and/or mention this in a comment)? (And when we agreed on the exact mechanics of Builtin_VectorTypeElementType we can delete BarePointerLike again, unless we find it useful for some other reason.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's addressing my underlying concern: anything that isn't named "vectorElementTypeInterface" or something like this should be a separate PR that we can reason about on its own merits.

The fundamental problem here is that there is nothing special about pointers for vector element types from a builtin type point of view, I don't quite see a workaround to this right now TBH.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

9 participants