Skip to content

Commit 5b7422a

Browse files
[mlir][IR] Experiment: Allow ptr as vector element type
1 parent 3bd11b5 commit 5b7422a

File tree

9 files changed

+37
-4
lines changed

9 files changed

+37
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
1515
#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
1616

17+
#include "mlir/IR/BuiltinTypes.h"
1718
#include "mlir/IR/Types.h"
1819
#include "mlir/Interfaces/DataLayoutInterfaces.h"
1920
#include "mlir/Interfaces/MemorySlotInterfaces.h"

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
1313
include "mlir/IR/AttrTypeBase.td"
14+
include "mlir/IR/BuiltinTypes.td"
1415
include "mlir/Interfaces/DataLayoutInterfaces.td"
1516
include "mlir/Interfaces/MemorySlotInterfaces.td"
1617

@@ -257,7 +258,8 @@ def LLVMStructType : LLVMType<"LLVMStruct", "struct", [
257258

258259
def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
259260
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
260-
"getIndexBitwidth", "areCompatible", "verifyEntries"]>]> {
261+
"getIndexBitwidth", "areCompatible", "verifyEntries"]>,
262+
PointerLike]> {
261263
let summary = "LLVM pointer type";
262264
let description = [{
263265
The `!llvm.ptr` type is an LLVM pointer type. This type typically represents

mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "mlir/Interfaces/DataLayoutInterfaces.td"
1313
include "mlir/IR/AttrTypeBase.td"
1414
include "mlir/IR/BuiltinTypeInterfaces.td"
15+
include "mlir/IR/BuiltinTypes.td"
1516
include "mlir/IR/OpBase.td"
1617

1718
//===----------------------------------------------------------------------===//
@@ -38,7 +39,8 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
3839
def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
3940
MemRefElementTypeInterface,
4041
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
41-
"areCompatible", "getIndexBitwidth", "verifyEntries"]>
42+
"areCompatible", "getIndexBitwidth", "verifyEntries"]>,
43+
PointerLike
4244
]> {
4345
let summary = "pointer type";
4446
let description = [{

mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H
1414
#define MLIR_DIALECT_PTR_IR_PTRTYPES_H
1515

16+
#include "mlir/IR/BuiltinTypes.h"
1617
#include "mlir/IR/Types.h"
1718
#include "mlir/Interfaces/DataLayoutInterfaces.h"
1819

mlir/include/mlir/IR/BuiltinTypes.h

+4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ template <typename ConcreteType>
4343
class ValueSemantics
4444
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
4545

46+
/// Type trait indicating that the type is a pointer-like type.
47+
template <typename ConcreteType>
48+
class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
49+
4650
//===----------------------------------------------------------------------===//
4751
// TensorType
4852
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.td

+6-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
4040
let cppNamespace = "::mlir";
4141
}
4242

43+
/// Type trait indicating that the type is a pointer-like type.
44+
def PointerLike : NativeTypeTrait<"PointerLike"> {
45+
let cppNamespace = "::mlir";
46+
}
47+
4348
//===----------------------------------------------------------------------===//
4449
// ComplexType
4550
//===----------------------------------------------------------------------===//
@@ -1238,7 +1243,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
12381243
// VectorType
12391244
//===----------------------------------------------------------------------===//
12401245

1241-
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
1246+
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat, AnyPointerLike]> {
12421247
let cppFunctionName = "isValidVectorTypeElementType";
12431248
}
12441249

mlir/include/mlir/IR/CommonTypeConstraints.td

+2
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
301301
"::mlir::IndexType">,
302302
BuildableType<"$_builder.getIndexType()">;
303303

304+
def AnyPointerLike : Type<CPred<"$_self.hasTrait<::mlir::PointerLike>()">, "pointer-like", "::mlir::Type">;
305+
304306
// Any signless integer type or index type.
305307
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
306308
"signless integer or index">;

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,8 @@ bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
876876
}
877877

878878
bool mlir::LLVM::isCompatibleVectorType(Type type) {
879-
if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type))
879+
if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, LLVMPointerType>(
880+
type))
880881
return true;
881882

882883
if (auto vecType = llvm::dyn_cast<VectorType>(type)) {

mlir/test/IR/test-verifiers-type.mlir

+15
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,18 @@
77

88
// expected-error @below{{failed to verify 'param': 16-bit signless integer or 32-bit signless integer}}
99
"test.type_producer"() : () -> !test.type_verification<f16>
10+
11+
// -----
12+
13+
// CHECK: "test.type_producer"() : () -> vector<!ptr.ptr<5 : i64>>
14+
"test.type_producer"() : () -> vector<!ptr.ptr<5>>
15+
16+
// -----
17+
18+
// CHECK: "test.type_producer"() : () -> vector<!llvm.ptr<1>>
19+
"test.type_producer"() : () -> vector<!llvm.ptr<1>>
20+
21+
// -----
22+
23+
// expected-error @below{{failed to verify 'elementType': integer or index or floating-point or pointer-like}}
24+
"test.type_producer"() : () -> vector<memref<2xf32>>

0 commit comments

Comments
 (0)