Skip to content

Commit 52fac3b

Browse files
[mlir][IR] Experiment: Allow ptr as vector element type
1 parent 491d3df commit 52fac3b

25 files changed

+132
-92
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
1515
#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_
1616

17+
#include "mlir/IR/BuiltinTypes.h"
1718
#include "mlir/IR/Types.h"
1819
#include "mlir/Interfaces/DataLayoutInterfaces.h"
1920
#include "mlir/Interfaces/MemorySlotInterfaces.h"

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
1313
include "mlir/IR/AttrTypeBase.td"
14+
include "mlir/IR/BuiltinTypes.td"
1415
include "mlir/Interfaces/DataLayoutInterfaces.td"
1516
include "mlir/Interfaces/MemorySlotInterfaces.td"
1617

@@ -259,7 +260,8 @@ def LLVMStructType : LLVMType<"LLVMStruct", "struct", [
259260
def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
260261
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
261262
"getIndexBitwidth", "areCompatible", "verifyEntries",
262-
"getPreferredAlignment"]>]> {
263+
"getPreferredAlignment"]>,
264+
PointerLike]> {
263265
let summary = "LLVM pointer type";
264266
let description = [{
265267
The `!llvm.ptr` type is an LLVM pointer type. This type typically represents

mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "mlir/Interfaces/DataLayoutInterfaces.td"
1313
include "mlir/IR/AttrTypeBase.td"
1414
include "mlir/IR/BuiltinTypeInterfaces.td"
15+
include "mlir/IR/BuiltinTypes.td"
1516
include "mlir/IR/OpBase.td"
1617

1718
//===----------------------------------------------------------------------===//
@@ -39,7 +40,8 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
3940
MemRefElementTypeInterface,
4041
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
4142
"areCompatible", "getIndexBitwidth", "verifyEntries",
42-
"getPreferredAlignment"]>
43+
"getPreferredAlignment"]>,
44+
PointerLike
4345
]> {
4446
let summary = "pointer type";
4547
let description = [{

mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_PTR_IR_PTRTYPES_H
1515

1616
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
17+
#include "mlir/IR/BuiltinTypes.h"
1718
#include "mlir/IR/Types.h"
1819
#include "mlir/Interfaces/DataLayoutInterfaces.h"
1920

mlir/include/mlir/IR/BuiltinTypes.h

+4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ template <typename ConcreteType>
4343
class ValueSemantics
4444
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
4545

46+
/// Type trait indicating that the type is a pointer-like type.
47+
template <typename ConcreteType>
48+
class PointerLike : public TypeTrait::TraitBase<ConcreteType, PointerLike> {};
49+
4650
//===----------------------------------------------------------------------===//
4751
// TensorType
4852
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.td

+11-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> {
4040
let cppNamespace = "::mlir";
4141
}
4242

43+
/// Type trait indicating that the type is a pointer-like type.
44+
def PointerLike : NativeTypeTrait<"PointerLike"> {
45+
let cppNamespace = "::mlir";
46+
}
47+
4348
//===----------------------------------------------------------------------===//
4449
// ComplexType
4550
//===----------------------------------------------------------------------===//
@@ -1249,7 +1254,12 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
12491254
// VectorType
12501255
//===----------------------------------------------------------------------===//
12511256

1252-
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
1257+
// Note: VectorType supports pointer-like types as element types. Examples for
1258+
// pointer-like types are !llvm.ptr and !ptr.ptr. This makes the MLIR vector
1259+
// type design symmetric to the LLVM vector type. That's desirable because the
1260+
// MLIR vector type is used in the LLVM dialect.
1261+
def Builtin_VectorTypeElementType
1262+
: AnyTypeOf<[AnyInteger, Index, AnyFloat, AnyPointerLike]> {
12531263
let cppFunctionName = "isValidVectorTypeElementType";
12541264
}
12551265

mlir/include/mlir/IR/CommonTypeConstraints.td

+2
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
301301
"::mlir::IndexType">,
302302
BuildableType<"$_builder.getIndexType()">;
303303

304+
def AnyPointerLike : Type<CPred<"$_self.hasTrait<::mlir::PointerLike>()">, "pointer-like", "::mlir::Type">;
305+
304306
// Any signless integer type or index type.
305307
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
306308
"signless integer or index">;

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,13 @@ static bool isSupportedTypeForConversion(Type type) {
140140
if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
141141
return false;
142142

143-
// Scalable types are not supported.
144-
if (auto vectorType = dyn_cast<VectorType>(type))
143+
if (auto vectorType = dyn_cast<VectorType>(type)) {
144+
// Vectors of pointers cannot be casted.
145+
if (isa<LLVM::LLVMPointerType>(vectorType.getElementType()))
146+
return false;
147+
// Scalable types are not supported.
145148
return !vectorType.isScalable();
149+
}
146150
return true;
147151
}
148152

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
690690
}
691691

692692
bool LLVMFixedVectorType::isValidElementType(Type type) {
693-
return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
693+
return llvm::isa<LLVMPPCFP128Type>(type);
694694
}
695695

696696
LogicalResult
@@ -890,7 +890,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
890890
if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
891891
return intType.isSignless();
892892
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
893-
Float80Type, Float128Type>(elementType);
893+
Float80Type, Float128Type, LLVMPointerType>(elementType);
894894
}
895895
return false;
896896
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

+28-29
Original file line numberDiff line numberDiff line change
@@ -2002,8 +2002,8 @@ func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1
20022002
}
20032003

20042004
// CHECK-LABEL: func @gather
2005-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
2006-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2005+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
2006+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
20072007
// CHECK: return %[[G]] : vector<3xf32>
20082008

20092009
// -----
@@ -2015,8 +2015,8 @@ func.func @gather_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
20152015
}
20162016

20172017
// CHECK-LABEL: func @gather_scalable
2018-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2019-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2018+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
2019+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
20202020
// CHECK: return %[[G]] : vector<[3]xf32>
20212021

20222022
// -----
@@ -2028,8 +2028,8 @@ func.func @gather_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %
20282028
}
20292029

20302030
// CHECK-LABEL: func @gather_global_memory
2031-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
2032-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2031+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> vector<3x!llvm.ptr<1>>, f32
2032+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
20332033
// CHECK: return %[[G]] : vector<3xf32>
20342034

20352035
// -----
@@ -2041,8 +2041,8 @@ func.func @gather_global_memory_scalable(%arg0: memref<?xf32, 1>, %arg1: vector<
20412041
}
20422042

20432043
// CHECK-LABEL: func @gather_global_memory_scalable
2044-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr<1>>, f32
2045-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2044+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> vector<[3]x!llvm.ptr<1>>, f32
2045+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
20462046
// CHECK: return %[[G]] : vector<[3]xf32>
20472047

20482048
// -----
@@ -2055,8 +2055,8 @@ func.func @gather_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: v
20552055
}
20562056

20572057
// CHECK-LABEL: func @gather_index
2058-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
2059-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
2058+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
2059+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
20602060
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<3xi64> to vector<3xindex>
20612061

20622062
// -----
@@ -2068,13 +2068,12 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
20682068
}
20692069

20702070
// CHECK-LABEL: func @gather_index_scalable
2071-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
2072-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
2071+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
2072+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
20732073
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<[3]xi64> to vector<[3]xindex>
20742074

20752075
// -----
20762076

2077-
20782077
func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
20792078
%0 = arith.constant 3 : index
20802079
%1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
@@ -2083,8 +2082,8 @@ func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2
20832082

20842083
// CHECK-LABEL: func @gather_1d_from_2d
20852084
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2086-
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
2087-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
2085+
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
2086+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<4x!llvm.ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
20882087
// CHECK: return %[[G]] : vector<4xf32>
20892088

20902089
// -----
@@ -2097,8 +2096,8 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
20972096

20982097
// CHECK-LABEL: func @gather_1d_from_2d_scalable
20992098
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2100-
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
2101-
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 4 x ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
2099+
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
2100+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[4]x!llvm.ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
21022101
// CHECK: return %[[G]] : vector<[4]xf32>
21032102

21042103
// -----
@@ -2114,8 +2113,8 @@ func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi
21142113
}
21152114

21162115
// CHECK-LABEL: func @scatter
2117-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
2118-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
2116+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
2117+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
21192118

21202119
// -----
21212120

@@ -2126,8 +2125,8 @@ func.func @scatter_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
21262125
}
21272126

21282127
// CHECK-LABEL: func @scatter_scalable
2129-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2130-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
2128+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
2129+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
21312130

21322131
// -----
21332132

@@ -2138,8 +2137,8 @@ func.func @scatter_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2:
21382137
}
21392138

21402139
// CHECK-LABEL: func @scatter_index
2141-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
2142-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into !llvm.vec<3 x ptr>
2140+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
2141+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into vector<3x!llvm.ptr>
21432142

21442143
// -----
21452144

@@ -2150,8 +2149,8 @@ func.func @scatter_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xinde
21502149
}
21512150

21522151
// CHECK-LABEL: func @scatter_index_scalable
2153-
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
2154-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
2152+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
2153+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
21552154

21562155
// -----
21572156

@@ -2163,8 +2162,8 @@ func.func @scatter_1d_into_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg
21632162

21642163
// CHECK-LABEL: func @scatter_1d_into_2d
21652164
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2166-
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
2167-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into !llvm.vec<4 x ptr>
2165+
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
2166+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into vector<4x!llvm.ptr>
21682167

21692168
// -----
21702169

@@ -2176,8 +2175,8 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
21762175

21772176
// CHECK-LABEL: func @scatter_1d_into_2d_scalable
21782177
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
2179-
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
2180-
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.vec<? x 4 x ptr>
2178+
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
2179+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>
21812180

21822181
// -----
21832182

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

+4-4
Original file line numberDiff line numberDiff line change
@@ -1669,8 +1669,8 @@ func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2:
16691669
}
16701670

16711671
// CHECK-LABEL: func @gather_with_mask
1672-
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
1673-
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
1672+
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
1673+
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
16741674

16751675
// -----
16761676

@@ -1685,8 +1685,8 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
16851685
}
16861686

16871687
// CHECK-LABEL: func @gather_with_mask_scalable
1688-
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
1689-
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
1688+
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
1689+
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
16901690

16911691

16921692
// -----

mlir/test/Dialect/LLVMIR/invalid.mlir

+5-5
Original file line numberDiff line numberDiff line change
@@ -1328,16 +1328,16 @@ func.func @invalid_bitcast_i64_to_ptr() {
13281328

13291329
// -----
13301330

1331-
func.func @invalid_bitcast_vec_to_ptr(%arg : !llvm.vec<4 x ptr>) {
1331+
func.func @invalid_bitcast_vec_to_ptr(%arg : vector<4x!llvm.ptr>) {
13321332
// expected-error@+1 {{cannot cast vector of pointers to pointer}}
1333-
%0 = llvm.bitcast %arg : !llvm.vec<4 x ptr> to !llvm.ptr
1333+
%0 = llvm.bitcast %arg : vector<4x!llvm.ptr> to !llvm.ptr
13341334
}
13351335

13361336
// -----
13371337

13381338
func.func @invalid_bitcast_ptr_to_vec(%arg : !llvm.ptr) {
13391339
// expected-error@+1 {{cannot cast pointer to vector of pointers}}
1340-
%0 = llvm.bitcast %arg : !llvm.ptr to !llvm.vec<4 x ptr>
1340+
%0 = llvm.bitcast %arg : !llvm.ptr to vector<4x!llvm.ptr>
13411341
}
13421342

13431343
// -----
@@ -1349,9 +1349,9 @@ func.func @invalid_bitcast_addr_cast(%arg : !llvm.ptr<1>) {
13491349

13501350
// -----
13511351

1352-
func.func @invalid_bitcast_addr_cast_vec(%arg : !llvm.vec<4 x ptr<1>>) {
1352+
func.func @invalid_bitcast_addr_cast_vec(%arg : vector<4x!llvm.ptr<1>>) {
13531353
// expected-error@+1 {{cannot cast pointers of different address spaces, use 'llvm.addrspacecast' instead}}
1354-
%0 = llvm.bitcast %arg : !llvm.vec<4 x ptr<1>> to !llvm.vec<4 x ptr>
1354+
%0 = llvm.bitcast %arg : vector<4x!llvm.ptr<1>> to vector<4x!llvm.ptr>
13551355
}
13561356

13571357
// -----

mlir/test/Dialect/LLVMIR/mem2reg.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ llvm.func @load_first_vector_elem() -> i16 {
10111011
llvm.func @load_first_llvm_vector_elem() -> i16 {
10121012
%0 = llvm.mlir.constant(1 : i32) : i32
10131013
// CHECK: llvm.alloca
1014-
%1 = llvm.alloca %0 x !llvm.vec<4 x ptr> : (i32) -> !llvm.ptr
1014+
%1 = llvm.alloca %0 x vector<4x!llvm.ptr> : (i32) -> !llvm.ptr
10151015
%2 = llvm.load %1 : !llvm.ptr -> i16
10161016
llvm.return %2 : i16
10171017
}

mlir/test/Dialect/LLVMIR/opaque-ptr.mlir

+3-3
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ llvm.func @opaque_ptr_masked_load(%arg0: !llvm.ptr, %arg1: vector<7xi1>) -> vect
6868
}
6969

7070
// CHECK-LABEL: @opaque_ptr_gather
71-
llvm.func @opaque_ptr_gather(%M: !llvm.vec<7 x ptr>, %mask: vector<7xi1>) -> vector<7xf32> {
71+
llvm.func @opaque_ptr_gather(%M: vector<7x!llvm.ptr>, %mask: vector<7xi1>) -> vector<7xf32> {
7272
// CHECK: = llvm.intr.masked.gather
73-
// CHECK: (!llvm.vec<7 x ptr>, vector<7xi1>) -> vector<7xf32>
73+
// CHECK: (vector<7x!llvm.ptr>, vector<7xi1>) -> vector<7xf32>
7474
%a = llvm.intr.masked.gather %M, %mask { alignment = 1: i32} :
75-
(!llvm.vec<7 x ptr>, vector<7xi1>) -> vector<7xf32>
75+
(vector<7x!llvm.ptr>, vector<7xi1>) -> vector<7xf32>
7676
llvm.return %a : vector<7xf32>
7777
}

0 commit comments

Comments
 (0)