-
Notifications
You must be signed in to change notification settings - Fork 13.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Type conversion of operands in new elementwise-op. #131542
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Javed Absar (javedabsar1) Changes — Full diff: https://github.com/llvm/llvm-project/pull/131542.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 308e39a9a51e1..af85daca1c078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -563,13 +563,16 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
The number of dims of the iterator-types are inferred from the rank of
the result type.
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the result.
+
Example:
Defining a unary linalg.elemwise with default indexing-map:
```mlir
%exp = linalg.elemwise
kind=#linalg.elemwise_kind<exp>
- ins(%x : tensor<4x16x8xf32>)
+ ins(%x : tensor<4x16x8xf16>)
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
```
@@ -587,7 +590,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
ElementwiseKindAttr:$kind,
- DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..0ffa259023faf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4250,17 +4250,36 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
SmallVector<Value> yields;
Value result;
+ TypeFn castVal = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castVal = attr.getValue();
+ }
+
if (arityGroup == ElementwiseArityGroup::Unary) {
- result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+ Value val0 = helper.buildTypeFn(castVal, block.getArgument(1).getType(),
+ block.getArgument(0));
+ result = helper.buildUnaryFn(kind.unaryFn, val0);
} else if (arityGroup == ElementwiseArityGroup::Binary) {
- result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
- block.getArgument(1));
+ Value val0 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value val1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(1));
+ result = helper.buildBinaryFn(kind.binaryFn, val0, val1);
} else if (arityGroup == ElementwiseArityGroup::Ternary) {
- result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
- block.getArgument(1), block.getArgument(2));
-
+ // select op's select-arg (block arg 0) must remain bool.
+ Value val1 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+ block.getArgument(1));
+ Value val2 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+ block.getArgument(2));
+ result =
+ helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0), val1, val2);
} else
assert(false && "found unhandled category in elemwise");
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
index e884858c016f4..19fb0e61d450b 100644
--- a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
@@ -163,3 +163,27 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
}
+
+// -----
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @cast_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK: %[[MUL:.+]] = arith.mulf %[[CAST]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[MUL]] : f32
+//
+func.func @cast_f16_to_f32(%A : tensor<16x8xf16>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<mul>
+ ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
index 20ebdd992b5a1..0bce89ca378a4 100644
--- a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
@@ -88,3 +88,41 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
return %r : tensor<1x2x3x4x5xi32>
}
+
+// -----
+
+// CHECK: @convert_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>,
+// CHECK-SAME: %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf16>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @convert_f16_to_f32(%A: tensor<16x8xf16>, %B: tensor<16x8xf32>,
+ %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<div>
+ ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
+
+
+// -----
+
+// CHECK: @explicit_cast(%[[A:.+]]: tensor<16x8xi16>, %[[B:.+]]: tensor<16x8xi32>,
+// CHECK-SAME: %[[C:.+]]: tensor<16x8xi32>) -> tensor<16x8xi32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<add>
+// CHECK-SAME: {cast = #linalg.type_fn<cast_signed>}
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xi16>, tensor<16x8xi32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xi32>) -> tensor<16x8xi32>
+//
+func.func @explicit_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
+ %0 = linalg.elementwise
+ kind=#linalg.elementwise_kind<add>
+ {cast = #linalg.type_fn<cast_signed>}
+ ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
+ outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
+ return %0 : tensor<16x8xi32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm confused. This does not add a comp-type to element-wise, just a cast operation to the output-type. The PR description seems wrong.
Also, as @MaheshRavishankar noted earlier, there are cases where you have different casts for different inputs. I think for now we can assume the output type isn't casted.
func.func @explicit_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
%0 = linalg.elementwise
kind=#linalg.elementwise_kind<add>
{cast = #linalg.type_fn<cast_signed>}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You also want test for unsigned cast.
How does this relate to: Is it completely unrelated (IIUC, Thanks Javed! |
Apologies @rengolin and @banach-space . I should have explained better . OK (a), lets take example.
after linalg-generalize:
Now, the same thing in new elementwise
after generalize -
---------------=====================--------------------------
Generalize -
Now for new elementwise
Generalize-
----------------=======================------------------
Generalize
New Op:
Generalize
|
comp-type
to new elementwise-op.
Gentle ping! @rengolin @banach-space |
Gentle Ping again! |
No apology required, this is obviously super clear in your head. You wouldn't know what's not clear in ours if we didn't ask :)
OK, just to clarify. From the examples that you shared,
All in all, makes sense to me. I've left a few minor comments. My bigger comment is to update the summary to communicate the actual extension when compare to:
Once that's done, I am happy to approve. Many thanks for working on this 🙏🏻 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry about the delay, Javed!
Numeric casting is performed on the input operand, promoting it to the same
data type as the result.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you document that there's default casting and that it can be specialised with the cast attribute?
Example:
Defining a unary linalg.elemwise with default indexing-map:
```mlir
%exp = linalg.elemwise
kind=#linalg.elemwise_kind<exp>
ins(%x : tensor<4x16x8xf32>)
ins(%x : tensor<4x16x8xf16>)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So you have changed this example so that there's casting. But what kind of casting? And why is it crucial? It would be good to expand docs.
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf16>, tensor<16x8xf32>)
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @convert_f16_to_f32(%A: tensor<16x8xf16>, %B: tensor<16x8xf32>,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note, you are not following the naming convention documented at the top.
// CHECK: %[[MUL:.+]] = arith.mulf %[[CAST]], %[[B_ARG]] : f32
// CHECK: linalg.yield %[[MUL]] : f32
//
func.func @cast_f16_to_f32(%A : tensor<16x8xf16>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] You are not following the existing naming convention from this file.
Also, a test with a non-default cast attribute would also be helpful.
if (arityGroup == ElementwiseArityGroup::Unary) {
result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
Value val0 = helper.buildTypeFn(castVal, block.getArgument(1).getType(),
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These val0 and val1 are quite enigmatic. I don't quite see what these mean. Could you use more descriptive names? Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is better off if you t
@@ -587,7 +590,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
ElementwiseKindAttr:$kind,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be better to make this a list of TypeFnAttr which allows for a sentinel none if no casting is required. Some parser/printer helpers can allow something like [cast_signed, -] to say no casting is needed. I'd also say this list should be of the size of the number of ins operands.
Prior to this diff, the elementwise-op's input and output types had to match.
Now, type conversion happens on the inputs, casting them to the result type
before performing the elementwise operation.