[mlir][linalg] Type conversion of operands in new elementwise-op. #131542

Status: Open. Wants to merge 1 commit into main.

Conversation

javedabsar1 (Contributor) commented Mar 16, 2025

Previously, the input and output types of elementwise-op had to match. With this change, the inputs are converted to the result's element type before the elementwise operator is applied.
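
For illustration, a minimal sketch of the new behavior (the shapes here mirror the tests added by this PR):

```mlir
// The f16 input is implicitly promoted to the f32 result element type;
// generalization materializes an arith.extf before the multiply.
%r = linalg.elementwise
    kind=#linalg.elementwise_kind<mul>
    ins(%a, %b : tensor<16x8xf16>, tensor<16x8xf32>)
    outs(%c : tensor<16x8xf32>) -> tensor<16x8xf32>
```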

llvmbot (Member) commented Mar 16, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Javed Absar (javedabsar1)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/131542.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+6-2)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+25-6)
  • (modified) mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir (+24)
  • (modified) mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir (+38)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 308e39a9a51e1..af85daca1c078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -563,13 +563,16 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
     The number of dims of the iterator-types are inferred from the rank of
     the result type.
 
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the result.
+
     Example:
 
     Defining a unary linalg.elemwise with default indexing-map:
     ```mlir
     %exp = linalg.elemwise
         kind=#linalg.elemwise_kind<exp>
-        ins(%x : tensor<4x16x8xf32>)
+        ins(%x : tensor<4x16x8xf16>)
         outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
     ```
 
@@ -587,7 +590,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs,
       ElementwiseKindAttr:$kind,
-      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
     );
 
   let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..0ffa259023faf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4250,17 +4250,36 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   SmallVector<Value> yields;
   Value result;
 
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
   if (arityGroup == ElementwiseArityGroup::Unary) {
-    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+    Value val0 = helper.buildTypeFn(castVal, block.getArgument(1).getType(),
+                                    block.getArgument(0));
+    result = helper.buildUnaryFn(kind.unaryFn, val0);
 
   } else if (arityGroup == ElementwiseArityGroup::Binary) {
-    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
-                                  block.getArgument(1));
+    Value val0 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+    Value val1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+    result = helper.buildBinaryFn(kind.binaryFn, val0, val1);
 
   } else if (arityGroup == ElementwiseArityGroup::Ternary) {
-    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
-                                   block.getArgument(1), block.getArgument(2));
-
+    // select op's select-arg (block arg 0) must remain bool.
+    Value val1 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+                                    block.getArgument(1));
+    Value val2 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+                                    block.getArgument(2));
+    result =
+        helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0), val1, val2);
   } else
     assert(false && "found unhandled category in elemwise");
 
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
index e884858c016f4..19fb0e61d450b 100644
--- a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
@@ -163,3 +163,27 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
       outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
   return %r : tensor<8x16x32xf32>
 }
+
+// -----
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @cast_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:  ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK:   %[[CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK:   %[[MUL:.+]] = arith.mulf %[[CAST]], %[[B_ARG]] : f32
+// CHECK:   linalg.yield %[[MUL]] : f32
+//
+func.func @cast_f16_to_f32(%A : tensor<16x8xf16>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_kind<mul>
+      ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+      outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
index 20ebdd992b5a1..0bce89ca378a4 100644
--- a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
@@ -88,3 +88,41 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
       outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
   return %r : tensor<1x2x3x4x5xi32>
 }
+
+// -----
+
+// CHECK: @convert_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>,
+// CHECK-SAME:                    %[[C:.+]]: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<div>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : tensor<16x8xf16>, tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @convert_f16_to_f32(%A: tensor<16x8xf16>, %B: tensor<16x8xf32>,
+                      %C: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_kind<div>
+      ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+      outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+
+
+// -----
+
+// CHECK: @explicit_cast(%[[A:.+]]: tensor<16x8xi16>, %[[B:.+]]: tensor<16x8xi32>,
+// CHECK-SAME:           %[[C:.+]]: tensor<16x8xi32>) -> tensor<16x8xi32> {
+// CHECK:  {{.*}} = linalg.elementwise
+// CHECK-SAME:          kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:          {cast = #linalg.type_fn<cast_signed>}
+// CHECK-SAME:          ins(%[[A]], %[[B]] : tensor<16x8xi16>, tensor<16x8xi32>)
+// CHECK-SAME:          outs(%[[C]] : tensor<16x8xi32>) -> tensor<16x8xi32>
+//
+func.func @explicit_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
+    %0 = linalg.elementwise
+            kind=#linalg.elementwise_kind<add>
+            {cast = #linalg.type_fn<cast_signed>}
+            ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
+            outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
+    return %0 : tensor<16x8xi32>
+}

rengolin (Member) left a comment:

I'm confused. This does not add a comp-type to element-wise, just a cast operation to the output-type. The PR description seems wrong.

Also, as @MaheshRavishankar noted earlier, there are cases where you have different casts for different inputs. I think for now we can assume the output type isn't casted.
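
To illustrate the limitation (a hypothetical case, not taken from this PR): with a single op-wide cast attribute, different conversions per input cannot be expressed.

```mlir
// Hypothetical: %a holds signed values and %b unsigned ones, so %a should be
// sign-extended and %b zero-extended to i32; the single `cast` attribute
// applies the same conversion to both inputs.
%r = linalg.elementwise
    kind=#linalg.elementwise_kind<add>
    ins(%a, %b : tensor<8xi16>, tensor<8xi16>)
    outs(%c : tensor<8xi32>) -> tensor<8xi32>
```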

func.func @explicit_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
%0 = linalg.elementwise
kind=#linalg.elementwise_kind<add>
{cast = #linalg.type_fn<cast_signed>}
You also want a test for the unsigned cast.
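
Such a test might look like the following sketch, modeled on the @explicit_cast test above (the function name is illustrative):

```mlir
func.func @unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
    // cast_unsigned promotes the i16 input with zero-extension (arith.extui).
    %0 = linalg.elementwise
            kind=#linalg.elementwise_kind<add>
            {cast = #linalg.type_fn<cast_unsigned>}
            ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
            outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
    return %0 : tensor<16x8xi32>
}
```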

rengolin requested a review from rolfmorel, March 16, 2025 23:03
banach-space (Contributor) commented:

How does this relate to:

Is it completely unrelated (IIUC, comp-type is something much more general, beyond linalg.elementwise) or is this a stepping stone towards the deprecation?

Thanks Javed!

javedabsar1 (Contributor, Author) commented Mar 17, 2025

I'm confused. This does not add a comp-type to element-wise, just a cast operation to the output-type. The PR description seems wrong.

Apologies @rengolin and @banach-space, I should have explained better.
My intention was, and still is, to:
(a) align with the legacy `linalg.elemwise_unary`/`linalg.elemwise_binary` ops;
(b) allow type conversion of the inputs of the new `linalg.elementwise` to the type of the result.

OK, for (a), let's take an example.
legacy IR (note i16 input):

 %0 = linalg.elemwise_unary 
                ins(%lhs: tensor<4x8xi16>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>

after linalg-generalize:

    %0 = linalg.generic 
               {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} 
                ins(%arg0 : tensor<4x8xi16>) outs(%arg1 : tensor<4x8xf32>) {
   ^bb0(%in: i16, %out: f32):
     %1 = arith.sitofp %in : i16 to f32
     %2 = math.exp %1 : f32
     linalg.yield %2 : f32
   } -> tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }

Now, the same thing in new elementwise

 %0 = linalg.elementwise 
             kind=#linalg.elementwise_kind<exp>
             ins(%arg0 : tensor<4x8xi16>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>

after generalize -

    %0 = linalg.generic
              {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
              ins(%arg0 : tensor<4x8xi16>) outs(%arg1 : tensor<4x8xf32>) {
    ^bb0(%in: i16, %out: f32):
      %1 = arith.sitofp %in : i16 to f32
      %2 = math.exp %1 : f32
      linalg.yield %2 : f32
    } -> tensor<4x8xf32>

---

Now for the role of 'cast' and the confusion with 'comp-type':
when promoting an int input to a float result, 'cast' decides whether to use sitofp (the default, as shown above) or uitofp.
Signed is the default and is covered above. So, for unsigned:
Legacy -

linalg.elemwise_unary 
          {cast = #linalg.type_fn<cast_unsigned>} 
          ins(%arg0 : tensor<4x8xi16>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>

Generalize -

    %0 = linalg.generic
               {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
               ins(%arg0 : tensor<4x8xi16>) outs(%arg1 : tensor<4x8xf32>) {
    ^bb0(%in: i16, %out: f32):
      %1 = arith.uitofp %in : i16 to f32
      %2 = math.exp %1 : f32
      linalg.yield %2 : f32
    } -> tensor<4x8xf32>

Now for new elementwise

%0 = linalg.elementwise 
               kind=#linalg.elementwise_kind<exp>
               {cast = #linalg.type_fn<cast_unsigned>}
               ins(%arg0 : tensor<4x8xi16>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>

Generalize-

...
^bb0(%in: i16, %out: f32):
      %1 = arith.uitofp %in : i16 to f32
      %2 = math.exp %1 : f32
      linalg.yield %2 : f32
    } -> tensor<4x8xf32>

---

Similarly, when going from 'fp' inputs to an 'i' result, the default 'cast=signed' generates fptosi on the inputs, while 'cast_unsigned' generates fptoui:

  %0 = linalg.elemwise_binary
               { fun = #linalg.binary_fn<add> , cast=#linalg.type_fn<cast_unsigned>}
               ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
               outs(%output: tensor<4x8xi32>) -> tensor<4x8xi32>

Generalize

    ^bb0(%in: f32, %in_0: f32, %out: i32):
      %1 = arith.fptoui %in : f32 to i32
      %2 = arith.fptoui %in_0 : f32 to i32
      %3 = arith.addi %1, %2 : i32
      linalg.yield %3 : i32

New Op:

%0 = linalg.elementwise 
                kind=#linalg.elementwise_kind<add>
               {cast = #linalg.type_fn<cast_unsigned>}
               ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<4x8xf32>)
               outs(%arg2 : tensor<4x8xi16>) -> tensor<4x8xi16>

Generalize

    ^bb0(%in: f32, %in_0: f32, %out: i16):
      %1 = arith.fptoui %in : f32 to i16
      %2 = arith.fptoui %in_0 : f32 to i16
      %3 = arith.addi %1, %2 : i16
      linalg.yield %3 : i16
    } -> tensor<4x8xi16>

javedabsar1 changed the title from "[mlir][linalg] Add comp-type to new elementwise-op." to "[mlir][linalg] Type conversion of operands in new elementwise-op." on Mar 17, 2025
javedabsar1 (Contributor, Author) commented:

Gentle ping! @rengolin @banach-space

javedabsar1 (Contributor, Author) commented:

Gentle ping again! @rengolin @banach-space

banach-space (Contributor) commented:

Apologies @rengolin and @banach-space, I should have explained better.

No apology required, this is obviously super clear in your head. You wouldn't know what's not clear in ours if we didn't ask :)

My intention was, and still is, to:
(a) align with the legacy `linalg.elemwise_unary`/`linalg.elemwise_binary` ops;
(b) allow type conversion of the inputs of the new `linalg.elementwise` to the type of the result.

OK, just to clarify. From the examples that you shared, linalg.elementwise already does casting similar to what the legacy Ops do, right? So there's nothing to align? I am going by this example:

Now, the same thing in new elementwise

 %0 = linalg.elementwise 
             kind=#linalg.elementwise_kind<exp>
             ins(%arg0 : tensor<4x8xi16>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
after generalize -

    %0 = linalg.generic
              {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
              ins(%arg0 : tensor<4x8xi16>) outs(%arg1 : tensor<4x8xf32>) {
    ^bb0(%in: i16, %out: f32):
      %1 = arith.sitofp %in : i16 to f32
      %2 = math.exp %1 : f32
      linalg.yield %2 : f32
    } -> tensor<4x8xf32>

All in all, makes sense to me. I've left a few minor comments. My bigger comment is to update the summary to communicate the actual extension when compared to:

  • the legacy Op
  • what's currently supported by the new Op.

Once that's done, I am happy to approve. Many thanks for working on this 🙏🏻

banach-space (Contributor) left a comment:

Sorry about the delay, Javed!

Comment on lines +566 to +567
Numeric casting is performed on the input operand, promoting it to the same
data type as the result.

Could you document that there's default casting and that it can be specialised with the cast attribute?

Example:

Defining a unary linalg.elemwise with default indexing-map:
```mlir
%exp = linalg.elemwise
    kind=#linalg.elemwise_kind<exp>
-   ins(%x : tensor<4x16x8xf32>)
+   ins(%x : tensor<4x16x8xf16>)
```

So you have changed this example so that there's casting. But what kind of casting? And why is it crucial? It would be good to expand docs.

// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf16>, tensor<16x8xf32>)
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @convert_f16_to_f32(%A: tensor<16x8xf16>, %B: tensor<16x8xf32>,

Note, you are not following the naming convention documented at the top.

// CHECK: %[[MUL:.+]] = arith.mulf %[[CAST]], %[[B_ARG]] : f32
// CHECK: linalg.yield %[[MUL]] : f32
//
func.func @cast_f16_to_f32(%A : tensor<16x8xf16>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {

[nit] You are not following the existing naming convention from this file.

Also, a test with a non-default cast attribute would also be helpful.

if (arityGroup == ElementwiseArityGroup::Unary) {
result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
Value val0 = helper.buildTypeFn(castVal, block.getArgument(1).getType(),

These val0 and val1 are quite enigmatic. I don't quite see what these mean. Could you use more descriptive names? Thanks!

MaheshRavishankar (Contributor) left a comment:

I think it is better off if you t

@@ -587,7 +590,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
ElementwiseKindAttr:$kind,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast

I think it would be better to make this a list of TypeFnAttr, which allows for a sentinel none if no casting is required. Some parser/printer helpers could allow something like [cast_signed, -] to say no casting is needed. I'd also say this list should be of the same size as the number of ins operands.
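
A sketch of what that proposal might look like in IR (the bracketed cast list and the `-` sentinel are hypothetical syntax, not part of this PR):

```mlir
// Hypothetical per-input cast list: the first input is sign-casted,
// while the '-' sentinel means the second input is used as-is.
%0 = linalg.elementwise
        kind=#linalg.elementwise_kind<add>
        {cast = [cast_signed, -]}
        ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
        outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
```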
