[mlir][Vector] Infer mask and pass_thru types for maskedload/store #131482

Open · wants to merge 2 commits into main

Conversation

@Groverkss (Member) commented Mar 16, 2025

The mask and pass_thru types can be completely inferred from the return type. There is no need to specify these types in the operation assembly format.

The type format now exactly matches vector.load and vector.store, with the only difference being that one takes a mask and the other doesn't.
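
For example (taken from the op documentation updated in this patch), a 1-D masked load that previously had to be written as

    %0 = vector.maskedload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>

is now written as

    %0 = vector.maskedload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<8xf32>

with the vector<8xi1> mask type and the vector<8xf32> pass_thru type inferred from the result type.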

@llvmbot (Member) commented Mar 16, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-memref

Author: Kunwar Grover (Groverkss)

Changes

Patch is 84.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131482.diff

20 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+25-8)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-23)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+1-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+8-8)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+8-8)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector.mlir (+16-16)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+7-7)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+5-5)
  • (modified) mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir (+3-3)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+15-12)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+32-32)
  • (modified) mlir/test/Dialect/Vector/vector-mem-transforms.mlir (+6-6)
  • (modified) mlir/test/mlir-tblgen/op-result.td (+7-11)
  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+23-28)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..fd77249402934 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1840,7 +1840,16 @@ def Vector_StoreOp : Vector_Op<"store"> {
 }
 
 def Vector_MaskedLoadOp :
-  Vector_Op<"maskedload">,
+    Vector_Op<"maskedload", [
+      AllTypesMatch<["result", "pass_thru"]>,
+      TypesMatchWith<"mask shape should match result shape",
+        "result",
+        "mask",
+        "VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
+          "IntegerType::get($_ctxt, 1),"
+          "::llvm::cast<VectorType>($_self).getScalableDims())">,
+      AllElementTypesMatch<["result", "base"]>
+    ]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1875,10 +1884,10 @@ def Vector_MaskedLoadOp :
 
     ```mlir
     %0 = vector.maskedload %base[%i], %mask, %pass_thru
-       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+       : memref<?xf32>, vector<8xf32>
 
     %1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
-       : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+       : memref<?x?xf32>, vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
@@ -1896,14 +1905,22 @@ def Vector_MaskedLoadOp :
     }
   }];
   let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
-    "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+    "type($base) `,` type($result)";
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
 
 def Vector_MaskedStoreOp :
-  Vector_Op<"maskedstore">,
+    Vector_Op<"maskedstore", [
+      TypesMatchWith<"mask shape should match result shape",
+        "valueToStore",
+        "mask",
+        "VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
+          "IntegerType::get($_ctxt, 1),"
+          "::llvm::cast<VectorType>($_self).getScalableDims())">,
+      AllElementTypesMatch<["valueToStore", "base"]>
+    ]>,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1937,10 +1954,10 @@ def Vector_MaskedStoreOp :
 
     ```mlir
     vector.maskedstore %base[%i], %mask, %value
-      : memref<?xf32>, vector<8xi1>, vector<8xf32>
+      : memref<?xf32>, vector<8xf32>
 
     vector.maskedstore %base[%i, %j], %mask, %value
-      : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+      : memref<?x?xf32>, vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
@@ -1956,7 +1973,7 @@ def Vector_MaskedStoreOp :
   }];
   let assemblyFormat =
       "$base `[` $indices `]` `,` $mask `,` $valueToStore "
-      "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
+      "attr-dict `:` type($base) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..83b962e54110a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5127,19 +5127,9 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
 //===----------------------------------------------------------------------===//
 
 LogicalResult MaskedLoadOp::verify() {
-  VectorType maskVType = getMaskVectorType();
-  VectorType passVType = getPassThruVectorType();
-  VectorType resVType = getVectorType();
-  MemRefType memType = getMemRefType();
-
-  if (resVType.getElementType() != memType.getElementType())
-    return emitOpError("base and result element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
-  if (resVType.getShape() != maskVType.getShape())
-    return emitOpError("expected result shape to match mask shape");
-  if (resVType != passVType)
-    return emitOpError("expected pass_thru of same type as result type");
+  int64_t memRank = getMemRefType().getRank();
+  if (llvm::size(getIndices()) != memRank)
+    return emitOpError("requires ") << memRank << " indices";
   return success();
 }
 
@@ -5181,16 +5171,9 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult MaskedStoreOp::verify() {
-  VectorType maskVType = getMaskVectorType();
-  VectorType valueVType = getVectorType();
-  MemRefType memType = getMemRefType();
-
-  if (valueVType.getElementType() != memType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
-  if (valueVType.getShape() != maskVType.getShape())
-    return emitOpError("expected valueToStore shape to match mask shape");
+  int64_t memRank = getMemRefType().getRank();
+  if (llvm::size(getIndices()) != memRank)
+    return emitOpError("requires ") << memRank << " indices";
   return success();
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4ae710aa29113..10224aec95d48 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -88,7 +88,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-NEXT:        %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
 // CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
-// CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi32>
 // CHECK:             %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
 // CHECK-NEXT:        scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..6dcae67abda57 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -1891,7 +1891,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
 
 func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
-  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
   return %0 : vector<16xf32>
 }
 
@@ -1906,7 +1906,7 @@ func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector
 
 func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) -> vector<[16]xf32> {
   %c0 = arith.constant 0: index
-  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
   return %0 : vector<[16]xf32>
 }
 
@@ -1921,7 +1921,7 @@ func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %a
 
 func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> {
   %c0 = arith.constant 0: index
-  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
   return %0 : vector<16xindex>
 }
 // CHECK-LABEL: func @masked_load_index
@@ -1931,7 +1931,7 @@ func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2
 
 func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) -> vector<[16]xindex> {
   %c0 = arith.constant 0: index
-  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex> into vector<[16]xindex>
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
   return %0 : vector<[16]xindex>
 }
 // CHECK-LABEL: func @masked_load_index_scalable
@@ -1945,7 +1945,7 @@ func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]
 
 func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
   %c0 = arith.constant 0: index
-  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
   return
 }
 
@@ -1959,7 +1959,7 @@ func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vecto
 
 func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) {
   %c0 = arith.constant 0: index
-  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
   return
 }
 
@@ -1973,7 +1973,7 @@ func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %
 
 func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) {
   %c0 = arith.constant 0: index
-  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex>
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
   return
 }
 // CHECK-LABEL: func @masked_store_index
@@ -1983,7 +1983,7 @@ func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg
 
 func.func @masked_store_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) {
   %c0 = arith.constant 0: index
-  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex>
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
   return
 }
 // CHECK-LABEL: func @masked_store_index_scalable
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 327cacf7d9a20..7246ae4884a19 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -837,7 +837,7 @@ func.func @fold_vector_load_subview(
 func.func @fold_vector_maskedload_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
-  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
   return %1 : vector<32xf32>
 }
 
@@ -847,7 +847,7 @@ func.func @fold_vector_maskedload_subview(
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
-//      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+//      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>
 
 // -----
 
@@ -871,7 +871,7 @@ func.func @fold_vector_store_subview(
 func.func @fold_vector_maskedstore_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
-  vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
+  vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
   return
 }
 
@@ -881,7 +881,7 @@ func.func @fold_vector_maskedstore_subview(
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
-//      CHECK:   vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
+//      CHECK:   vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>
 //      CHECK:   return
 
 // -----
@@ -907,7 +907,7 @@ func.func @fold_vector_maskedload_expand_shape(
   %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
   %c0 = arith.constant 0 : index
   %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
-  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
   return %1 : vector<8xf32>
 }
 
@@ -943,7 +943,7 @@ func.func @fold_vector_maskedstore_expand_shape(
   %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
   %c0 = arith.constant 0 : index
   %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
-  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
   return
 }
 
@@ -979,7 +979,7 @@ func.func @fold_vector_load_collapse_shape(
 func.func @fold_vector_maskedload_collapse_shape(
   %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
-  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
   return %1 : vector<8xf32>
 }
 
@@ -1017,7 +1017,7 @@ func.func @fold_vector_store_collapse_shape(
 func.func @fold_vector_maskedstore_collapse_shape(
   %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
-  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 364ba6e71ff3b..c50d44f7faa1e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -65,10 +65,10 @@
 // CHECK-VEC4-SVE:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] {
 // CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
 // CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
-// CHECK-VEC4-SVE:         %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+// CHECK-VEC4-SVE:         %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<?xf32>, vector<[4]xf32>
 // CHECK-VEC4-SVE:         %[[scalev:.*]] = vector.broadcast %{{.*}} : f32 to vector<[4]xf32>
 // CHECK-VEC4-SVE:         %[[scaled:.*]] = arith.mulf %[[val]], %[[scalev]] : vector<[4]xf32>
-// CHECK-VEC4-SVE:         vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32>
+// CHECK-VEC4-SVE:         vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xf32>
 // CHECK-VEC4-SVE:       }
 // CHECK-VEC4-SVE:       return
 //
@@ -136,9 +136,9 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
 // CHECK-VEC16:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
 // CHECK-VEC16:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC16:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC16:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC16:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi32>
 // CHECK-VEC16:         %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
-// CHECK-VEC16:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-VEC16:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
 // CHECK-VEC16:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
@@ -159,8 +159,8 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
 // CHECK-VEC16-IDX32:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
 // CHECK-VEC16-IDX32:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
 // CHECK-VEC16-IDX32:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
-// CHECK-VEC16-IDX32:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
-// CHECK-VEC16-IDX32:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC16-IDX32:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi32>
+// CHECK-VEC16-IDX32:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16-IDX32:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-VEC16-IDX32:         %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32>
 // CHECK-VEC16-IDX32:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
@@ -185,9 +185,9 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor
 // CHECK-VEC4-SVE:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[step]] {
 // CHECK-VEC4-SVE:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[step]]]
 // CHECK-VEC4-SVE:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>
-// CHECK-VEC4-SVE:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK-VEC4-SVE:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref<?xi32>, vector<[4]xi32>
 // CHECK-VEC4-SVE:         %[[lii64:.*]] = arith.extui %[[li]] : vector<[4]xi32> to vector<[4]xi64>
-// CHECK-VEC4-SVE:         %[[la:.*]] = vector.maskedload %{{.*...
[truncated]

@Groverkss (Member, Author) commented Mar 16, 2025

Depends on #131483

@kuhar (Member) left a comment

This looks like a welcome cleanup. The vector dialect part LGTM, but I'm not familiar enough with the tablegen part to meaningfully review.

Please wait for more approvals before landing.

@banach-space (Contributor) left a comment

Hi Kunwar, thanks for looking at ways to improve the codebase!

To be perfectly honest, I actually prefer the longer and more explicit spelling that we have today 😅. The cognitive load here:

%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex> into vector<[16]xindex>

feels much lower than in:

%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>

With an explicit type for every argument, it's immediately clear what each type corresponds to. By contrast, in the shorter version, it’s unclear which type belongs to which operand:

  • %arg0 + %arg1,
  • %arg0 + %arg2,
  • %arg1 + %arg2.

Sure, things take some getting used to, and perhaps over time, the proposed format would feel more natural. However, I’m particularly concerned about newcomers to Vector. IMHO, they would likely find this quite confusing - is there a clear way to map each type to its corresponding argument?

The type format now exactly matches vector.load and vector.store, with the only difference being that one takes a mask and the other doesn’t.

But isn’t this what we want to avoid? If the distinction between vector.load and vector.maskedload becomes too blurred, it may lead to more confusion rather than clarity.

Additionally, I’ve noticed that the auto-generated error messages aren’t as helpful with the shorter format.

Some of these things are a matter of personal preference - I've tried to explain my reasoning so that you know where I'm coming from. I think it would be good to hear from more people.

Thanks again for the effort - I really appreciate the time you put into this!

func.func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.maskedload' op expected result shape to match mask shape}}
%0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
// expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<16xi1>' vs 'vector<15xi1>'}}
Contributor commented:

TBH, I find the new error less informative than the original one. The new one refers to "prior uses", but what prior uses? The old one makes it super clear what the issue is.

Is there any way to preserve the more informative errors?

@Groverkss (Member, Author) commented Mar 19, 2025

TBH, I find the new error less informative than the original one. The new one refers to "prior uses", but what prior uses? The old one makes it super clear what the issue is.

It also shows what the prior uses are, just on a different line, at the point where they originated. It's much more helpful than the current error because:

  • it shows the type difference, and
  • it shows, as an error message on another line, where the type originated from.

@Groverkss (Member, Author) commented:

Let me see if I can make tablegen emit a better message.

@Groverkss (Member, Author) commented Mar 19, 2025

Actually, no, we cannot emit a better message than this. This is failing in the parser: when reading the whole IR, the parser finds two possible types for the mask, vector<16xi1> and vector<15xi1>. This is the equivalent error message you would get if you had passed the type vector<16xi1> to vector.maskedload while the actual value had the type vector<15xi1>.

This is actually now producing an error message for something different.

Before, we could have two problems:

  1. Type of mask passed to operation is wrong
  2. The actual type of mask is wrong

With this patch, 1 is not possible anymore, so the error is for 2.

This is actually better, because there is no possibility of writing the wrong type here anymore.

You can reproduce the same error before this patch if you passed vector<16xi1> as a type for the mask, but the actual value had the type vector<15xi1>.

TLDR: We cannot compare error message quality here, because they are error messages for different things. This patch completely eliminates the possibility of the old message from the parser.
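
For reference, here is a minimal sketch (a hypothetical test case, not taken from this patch) of how the same parser diagnostic could already be triggered before this patch, by spelling a mask type that disagrees with the mask value's actual type:

    func.func @mask_type_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
      %c0 = arith.constant 0 : index
      // The explicit vector<16xi1> mask type conflicts with %mask's declared vector<15xi1> type,
      // so the parser reports: use of value '%mask' expects different type than prior uses:
      // 'vector<16xi1>' vs 'vector<15xi1>'
      %0 = vector.maskedload %base[%c0], %mask, %pass
          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
      return
    }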

@Groverkss (Member, Author) commented Mar 19, 2025

Sure, things take some getting used to, and perhaps over time, the proposed format would feel more natural. However, I’m particularly concerned about newcomers to Vector. IMHO, they would likely find this quite confusing - is there a clear way to map each type to its corresponding argument?

I don't think this is a strong argument against this change. We have proper documentation for what the types should be, and we have documentation on the assembly format. The tablegen format explicitly documents how the arguments are inferred; it's not some C++ magic, it's explicitly written in the tablegen file:

      AllTypesMatch<["result", "pass_thru"]>,
      TypesMatchWith<"mask shape should match result shape",
        "result",
        "mask",
        "VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
          "IntegerType::get($_ctxt, 1),"
          "::llvm::cast<VectorType>($_self).getScalableDims())">,
      AllElementTypesMatch<["result", "base"]>

This is very explicit about how the types are inferred. New users should read the docs carefully before using things.

I agree type inference can be confusing sometimes (I talked about this with @kuhar, who had problems reading SPIR-V IR when it was doing some magic with type inference), but for mask/pass_thru it's very clear what their types should be based on the vector result type: pass_thru has the same type as the result, and mask has the same shape (and an i1 element type). There is very little cognitive overhead here.
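
As a small illustration (hypothetical values), with the new format the inferred types can be read directly off the result type:

    // A result type of vector<[8]xf32> implies %mask : vector<[8]xi1> (same shape, including
    // scalability, with an i1 element type) and %pass_thru : vector<[8]xf32> (same type as
    // the result).
    %0 = vector.maskedload %base[%i], %mask, %pass_thru : memref<?xf32>, vector<[8]xf32>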

But isn’t this what we want to avoid? If the distinction between vector.load and vector.maskedload becomes too blurred, it may lead to more confusion rather than clarity.

Adding redundant type information to one op just to improve its distinction from another op does not seem like good IR design to me.

Some of these things are a matter or personal preference - I've tried to explain my reasoning so that you know where I am coming from. I think that it would be good to hear from more people.

I'm only looking to remove redundant information that has a low cognitive overhead to recover from the IR. These things are documented properly in the docs and are fairly obvious once a new user reads them.

Do you have a strong preference here? I do have a strong preference for having less redundant information in the IR.

@kuhar (Member) commented Mar 19, 2025

I talked about this with @kuhar who had problems reading SPIRV IR when it was doing some magic with type inference

SPIR-V used to have very questionable type inference around load/store/access-chain instructions, where it would implicitly form the pointer type based on the value type and the storage class: #116545

However, for this PR, I don't think it takes much effort to figure out the mask/pass_thru type -- the only thing that changes is the element type.

@banach-space (Contributor) commented:

Do you have a strong preference here? I do have a strong preference for having less redundant information in the IR.

I do, but this shouldn’t be about personal preferences. To help unblock this, we agreed in today’s "Tensor Compiler" call that:

  • This should be split into two PRs (verifier changes + syntax changes).
  • @dcaballe, our masking expert, will review it to ensure broader input.
  • If "mask type" is removed from the syntax, it should be done consistently across all Ops that take masks.

Regardless of personal preferences, let’s maintain consistency.

Thanks for your contributions to Vector! 🙏🏻

@dcaballe (Contributor) commented:

Thanks a lot for the contribution! My suggestion is that we remove a certain level of redundancy, but that we don't go too far in doing so. I would remove the types that are strictly redundant, but I would keep the mask type. That is:

       memref<?xf32>, vector<8xi1>, vector<8xf32>

My particular reasoning is that this is a compiler intermediate representation, and a certain level of redundancy facilitates understanding. In particular, it could be confusing to hit an error about the mask type when the mask type is not even printed, so I think being more explicit about it would be helpful in general. I would expect much more emphasis on a compact textual form and other kinds of ergonomics if this were a user-facing language expected to be written by hand, but that's not the case.
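
Spelled out on a full op, that suggested middle ground would presumably look like the following (a hypothetical spelling, not part of this patch), keeping the mask type while dropping only the duplicated result type:

    %0 = vector.maskedload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<8xi1>, vector<8xf32>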

Hopefully that helps!

@Groverkss (Member, Author) commented Mar 20, 2025

Do you have a strong preference here? I do have a strong preference for having less redundant information in the IR.

I do, but this shouldn’t be about personal preferences. To help unblock this, we agreed in today’s "Tensor Compiler" call that:

  • This should be split into two PRs (verifier changes + syntax changes).
  • @dcaballe, our masking expert, will review it to ensure broader input.
  • If "mask type" is removed from the syntax, it should be done consistently across all Ops that take masks.

Regardless of personal preferences, let’s maintain consistency.

Thanks for your contributions to Vector! 🙏🏻

Sorry, I should have framed what I meant in the last sentence better; it may not have come across the way I intended. What I was trying to understand was whether this is something that will need a bigger discussion, because if so, it would be better to reach consensus before I make individual PRs (I had some more PRs for gather/scatter too). I think I got the answer though (thanks to your very actionable reply): whatever we decide on, let's be consistent, which is a good idea. We can discuss later what we want to end up with.

I was under the assumption that this would be considered a minor cleanup. After talking with Jakub and reading your comments, I see it's not as trivial as I thought. I think you have good points, and I have different opinions here, which is okay.

Overall, this is nice to have for me, but it's not something I'm willing to spend too much time on, which is what I meant in my original comment (strong preference was the wrong word there, sorry for that).

Thank you for the replies and reviews everyone!

@Groverkss (Member, Author) commented Mar 20, 2025

I'm going to split this into 3 PRs:

  1. PR1: Convert the verifiers for all vector ops that have a mask and pass_thru into tablegen constraints, keeping them consistent. This should really be an NFC.
  2. PR2: Remove the pass_thru type from the assembly format for all of the above ops. Based on @dcaballe's review, I think this should be uncontroversial as well.
  3. PR3: Remove the mask type from the assembly format. I think this is something we disagree on, and I'm willing to just let this PR hang until we decide to close it or land it.
