-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Extend elementwise #124661
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ | |
let hasCanonicalizer = 1; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Op definition for ElementwiseOp | ||
//===----------------------------------------------------------------------===// | ||
def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [ | ||
AttrSizedOperandSegments]> { | ||
let summary = [{ Performs element-wise operation }]; | ||
let description = [{ | ||
The attribute `kind` describes arithmetic operation to perform. The | ||
operation kind can be unary (e.g. max), binary (e.g. add) or ternary | ||
(e.g. select). | ||
|
||
By default, all indexing maps are identities. In the case of default | ||
indexing map, all input and output shapes must match. The number of dims in | ||
each of the identity maps is equal to the rank of the output type. | ||
|
||
Affine-maps for operands and result are required to be provided by the user | ||
when a transpose and/or broadcast is needed on any operand. When a map is not | ||
provided, default identity maps are inferred for each operand. | ||
|
||
Iterator-types are always all `parallel`. | ||
Iterator-types are needed for constructing the underlying structured op. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure why this is relevant here. This is somehow implying the op stores iterator types on the operation. This is confusing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed on this. |
||
|
||
The number of dims of the iterator-types are inferred from the rank of | ||
the result type. | ||
|
||
Example: | ||
|
||
Defining a unary linalg.elemwise with default indexing-map: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Old op's name, needs updating |
||
```mlir | ||
%exp = linalg.elemwise | ||
kind=#linalg.elemwise_kind<exp> | ||
ins(%x : tensor<4x16x8xf32>) | ||
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32> | ||
``` | ||
|
||
Defining a binary linalg.elemwise with user-defined indexing-map: | ||
```mlir | ||
%add = linalg.elemwise | ||
kind=#linalg.elemwise_kind<add> | ||
indexing_maps = [#transpose, #broadcast, #identity] | ||
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>) | ||
outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> | ||
``` | ||
}]; | ||
|
||
let arguments = (ins | ||
Variadic<AnyType>:$inputs, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the reason to accept There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks. changed to AnyShaped. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Left a comment about this below: the primary reason to allow |
||
Variadic<AnyShaped>:$outputs, | ||
ElementwiseKindAttr:$kind, | ||
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if this should be DefaultValuedOptionalAttr. "{}" can be invalid in many cases. Instead, we should just have a builder for having a derived default value of the attribute. |
||
); | ||
|
||
let results = (outs Variadic<AnyRankedTensor>:$result_tensors); | ||
let regions = (region AnyRegion:$region); | ||
let skipDefaultBuilders = 1; | ||
|
||
let builders = [ | ||
OpBuilder< | ||
(ins "ValueRange":$inputs, "ValueRange":$outputs, | ||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), | ||
[{ | ||
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs, | ||
attributes, ElementwiseOp::getRegionBuilder()); | ||
}]> | ||
]; | ||
|
||
let hasCustomAssemblyFormat = 1; | ||
let hasFolder = 1; | ||
let hasVerifier = 1; | ||
|
||
let extraClassDeclaration = structuredOpsBaseDecls # [{ | ||
/// Get the arity enum corresponding to the kind of op, e.g. if arg is | ||
/// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`. | ||
static ElementwiseArityGroup getArityGroup(ElementwiseKind n); | ||
|
||
/// Both user-specified and default indexing map will always depend on | ||
/// the current Op instance. | ||
static bool hasDynamicIndexingMaps() { return true; } | ||
|
||
/// Implements the block region builder for the elementwiseOp. This is | ||
/// called by the 'fillStructuredOpRegion'. | ||
static void regionBuilder(ImplicitLocOpBuilder &b, | ||
Block &block, ArrayRef<NamedAttribute> attrs); | ||
|
||
static std::function<void(ImplicitLocOpBuilder &, | ||
Block &, ArrayRef<NamedAttribute>)> | ||
getRegionBuilder() { | ||
return regionBuilder; | ||
} | ||
|
||
/// Returns rank of the result tensor/memref. Useful for knowing | ||
/// the dimensionality of the iteration space when others means | ||
/// are not possible e.g. absence of user-provided indexing map. | ||
unsigned getResultRank() { | ||
Value output = getDpsInitOperand(0)->get(); | ||
ShapedType shapedType = llvm::cast<ShapedType>(output.getType()); | ||
return shapedType.getRank(); | ||
} | ||
|
||
/// Returns N 'parallel' iterator types where N is rank of result. | ||
SmallVector<utils::IteratorType> getIteratorTypesArray(); | ||
rolfmorel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
/// The default indexing maps are identities. | ||
/// There will be N+1 such maps, where N is the arity of the Op. | ||
static SmallVector<AffineMap> | ||
getDefaultIndexingMaps(unsigned NumMaps, unsigned numDims, | ||
MLIRContext *context); | ||
|
||
/// Destination passing style interface method. | ||
::mlir::MutableOperandRange getDpsInitsMutable() { | ||
return getOutputsMutable(); | ||
} | ||
|
||
// Generic methods. | ||
std::string getLibraryCallName() { | ||
return generateLibraryCallName(getOperation()); | ||
} | ||
}]; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Op definition for MatmulOp | ||
//===----------------------------------------------------------------------===// | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking out loud, perhaps it would be easier to use bit patterns instead of joining enum lists. We won't have more than 20 operations per
Category
, so:(op | (0xFF << 1))
(op | (0xFF << 2))
(op | (0xFF << 3))
And set the enums above like:
I32EnumAttrCase<"log", (1 << 1)>
I32EnumAttrCase<"sub", (1 << 2)>
I32EnumAttrCase<"select", (1 << 3)>
etc?
Then you don't need all the complex sequential logic in the parser/verifiers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need anymore.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the alternative join-based approach, @javedabsar1 - It's quite impressive what can done with TableGen! (For better or worse, TableGen is a programming language on its own.)
I would still think an approach like @rengolin's would lead to simpler C++. The scheme I have in mind is to just shift the arity 30 bits, e.g.
I32EnumAttrCase<"abs", 2>
becomesI32EnumAttrCase<"abs", 2 + (1 << 30)>
,I32EnumAttrCase<"div", 3 + (2 << 30)>
I32EnumAttrCase<"select", 0 + (3 << 30)>
. This way the arity can be retrieved by just shifting right 30 bits (e.g.derivedEnumVal >> 30
) and to obtain the original op code you just doderivedEnumVal & ((1 << 30) - 1)
.Three nested
!lfold
s should now suffice to derive all the derived enum cases andElementwiseFnLimits
could go andNAryCategoryAndFn
could go or be simplified.What do you think? (If you could just state the benefits of your approach, that would also be fine OFC.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me make a pragmatic case for C++ ...
I've not seen much (any?) sophisticated TableGen in dialect definition. C++ might simply be more familiar to folks. I've already revealed my own personal preference 😅