[mlir][linalg] Extend Linalg elemwise named ops semantics #122753
Changes from 2 commits
627b06c
578998d
3c0d02c
4d1a40f
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,57 @@ | |
|
||
include "mlir/IR/EnumAttr.td" | ||
|
||
// Define an `enum class : i32` to categorise element-wise op. | ||
def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [ | ||
I32EnumAttrCase<"Unary", 0>, | ||
Review comment: Do we need separate enums here? Can we just use the enum below to classify what is a unary/binary or ternary function? (I think this is a repeat comment from above) |
||
I32EnumAttrCase<"Binary", 1>, | ||
I32EnumAttrCase<"Ternary", 2> | ||
]> { | ||
let genSpecializedAttr = 0; | ||
let cppNamespace = "::mlir::linalg"; | ||
} | ||
|
||
// Define a unified `enum class : i32` for all element-wise options. | ||
// Note: The order of individual fn (e.g. 'exp', 'log') within each | ||
// category (Unary, Binary etc.) must match the ordering of same fn | ||
// defined in UnaryFn, BinaryFn. This is to enable correct mapping | ||
// from this unified enum class to different category enums. | ||
def ElemwiseFn : I32EnumAttr<"ElemwiseFn", "", [ | ||
// Unary | ||
I32EnumAttrCase<"exp", 0>, | ||
I32EnumAttrCase<"log", 1>, | ||
I32EnumAttrCase<"abs", 2>, | ||
I32EnumAttrCase<"ceil", 3>, | ||
I32EnumAttrCase<"floor", 4>, | ||
I32EnumAttrCase<"negf", 5>, | ||
I32EnumAttrCase<"reciprocal", 6>, | ||
I32EnumAttrCase<"round", 7>, | ||
I32EnumAttrCase<"sqrt", 8>, | ||
I32EnumAttrCase<"rsqrt", 9>, | ||
I32EnumAttrCase<"square", 10>, | ||
I32EnumAttrCase<"tanh", 11>, | ||
I32EnumAttrCase<"erf", 12>, | ||
|
||
// Binary | ||
|
||
Review comment: nit: unnecessary extra new line |
||
I32EnumAttrCase<"add", 13>, | ||
I32EnumAttrCase<"sub", 14>, | ||
I32EnumAttrCase<"mul", 15>, | ||
I32EnumAttrCase<"div", 16>, | ||
I32EnumAttrCase<"div_unsigned", 17>, | ||
I32EnumAttrCase<"max_signed", 18>, | ||
I32EnumAttrCase<"min_signed", 19>, | ||
I32EnumAttrCase<"max_unsigned", 20>, | ||
I32EnumAttrCase<"min_unsigned", 21>, | ||
I32EnumAttrCase<"powf", 22>, | ||
|
||
// Ternary | ||
I32EnumAttrCase<"select", 23> | ||
]> { | ||
let genSpecializedAttr = 0; | ||
let cppNamespace = "::mlir::linalg"; | ||
} | ||
|
||
// Define the function attribute enums matching the OpDSL functions. | ||
def UnaryFn : I32EnumAttr<"UnaryFn", "", [ | ||
I32EnumAttrCase<"exp", 0>, | ||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -551,6 +551,136 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ | |||||||||
let hasCanonicalizer = 1; | ||||||||||
} | ||||||||||
|
||||||||||
//===----------------------------------------------------------------------===// | ||||||||||
// Op definition for ElemwiseOp - with user-defined maps, computation type etc. | ||||||||||
Review comment: [nit] Let's stick with the existing convention.
Suggested change
|
||||||||||
//===----------------------------------------------------------------------===// | ||||||||||
|
||||||||||
def ElemwiseOp : LinalgStructuredBase_Op<"elemwise", [ | ||||||||||
Review comment: Nit: I'd rather just use
Review comment: +1 for full spelling, feels more intuitive
Review comment: I too agree. I just wasn't sure what you folks would prefer (longer IR and precise, or shorter and guess-it) |
||||||||||
AttrSizedOperandSegments]> { | ||||||||||
let summary = [{ Performs element-wise operation }]; | ||||||||||
let description = [{ | ||||||||||
Linalg op form which performs element-wise computation. The attribute | ||||||||||
`func_type` describes the operation type (e.g. add, exp). The func_type | ||||||||||
can be any valid unary, binary, or ternary operation. | ||||||||||
Review comment: How are we defining "any valid"? I would either list all or skip this altogether. |
||||||||||
|
||||||||||
Affine-maps for operands and result may be provided by the user. When | ||||||||||
Review comment: IIUC, maps are required when transposing or broadcasting? |
||||||||||
a user-defined indexing_map is not provided, identity map is inferred | ||||||||||
for all operands. The default indexing maps are N identity-maps. ‘N’ | ||||||||||
depends on the arity of the elementwise op. The number of dims is | ||||||||||
Review comment on lines +568 to +569: Isn't N simply the rank of the output? And doesn't "arity" mean the number of function args? Is that relevant here? |
||||||||||
inferred from rank of the output type. In the case of default indexing | ||||||||||
map, the input and output shapes must all match. Affine-map for operands | ||||||||||
Review comment on lines +570 to +571:
Suggested change
|
||||||||||
and result must be only projected permutations with no zero constants. | ||||||||||
Review comment:
Suggested change
|
||||||||||
|
||||||||||
For element-wise iterator-type is always inferred as all ‘parallel’. | ||||||||||
Review comment:
Suggested change
or
Suggested change
|
||||||||||
Iterator-type is needed for constructing this underlying structured op. | ||||||||||
The number of dims of the iterator-type is inferred from the rank of | ||||||||||
the result type. | ||||||||||
|
||||||||||
Example: | ||||||||||
Defining a unary linalg.elemwise with default indexing-map: | ||||||||||
|
||||||||||
```mlir | ||||||||||
%exp = linalg.elemwise | ||||||||||
func_type=#linalg.elemwise_fn<exp> | ||||||||||
ins(%x : tensor<4x16x8xf32>) | ||||||||||
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32> | ||||||||||
``` | ||||||||||
|
||||||||||
Defining a binary linalg.elemwise with user-defined indexing-map: | ||||||||||
|
||||||||||
```mlir | ||||||||||
%add = linalg.elemwise | ||||||||||
func_type=#linalg.elemwise_fn<add> | ||||||||||
indexing_maps = [#transpose, #broadcast, #identity] | ||||||||||
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>) | ||||||||||
outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> | ||||||||||
``` | ||||||||||
}]; | ||||||||||
|
||||||||||
let arguments = (ins | ||||||||||
Variadic<AnyType>:$inputs, | ||||||||||
Variadic<AnyShaped>:$outputs, | ||||||||||
ElemwiseFnAttr:$func_type, | ||||||||||
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps | ||||||||||
); | ||||||||||
|
||||||||||
let results = (outs Variadic<AnyRankedTensor>:$result_tensors); | ||||||||||
let regions = (region AnyRegion:$region); | ||||||||||
Review comment: Do we need a region on the op?
Review comment: Just to cross-pollinate the Linalg discussions. |
||||||||||
|
||||||||||
let skipDefaultBuilders = 1; | ||||||||||
let builders = [ | ||||||||||
OpBuilder< | ||||||||||
(ins "ValueRange":$inputs, "ValueRange":$outputs, | ||||||||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), | ||||||||||
[{ | ||||||||||
buildElemwiseOp($_builder, $_state, std::nullopt, inputs, outputs, | ||||||||||
attributes, ElemwiseOp::getRegionBuilder()); | ||||||||||
}]> | ||||||||||
]; | ||||||||||
|
||||||||||
let hasCustomAssemblyFormat = 1; | ||||||||||
let hasFolder = 1; | ||||||||||
let hasVerifier = 1; | ||||||||||
|
||||||||||
let extraClassDeclaration = structuredOpsBaseDecls # [{ | ||||||||||
|
||||||||||
/// Get the nary category enum, e.g. `ElemwiseNAryCategory::Unary`, | ||||||||||
/// corresponding to the given fn, e.g. `ElemwiseFn::exp` | ||||||||||
static ElemwiseNAryCategory getNAryCategory(ElemwiseFn fn); | ||||||||||
|
||||||||||
/// Elementwise is always `dynamic indexing maps` i.e. `user specified` | ||||||||||
/// or `default`. Default is identity-maps. | ||||||||||
Review comment on lines +631 to +632: I've checked the comment for
Suggested change
|
||||||||||
static bool hasDynamicIndexingMaps() { return true; } | ||||||||||
|
||||||||||
/// Implements the block region builder for the ElemwiseOp. This is called | ||||||||||
Review comment:
Suggested change
|
||||||||||
/// by the 'fillStructuredOpRegion'. | ||||||||||
static void regionBuilder(ImplicitLocOpBuilder &b, | ||||||||||
Block &block, ArrayRef<NamedAttribute> attrs); | ||||||||||
|
||||||||||
static std::function<void(ImplicitLocOpBuilder &, | ||||||||||
Block &, ArrayRef<NamedAttribute>)> | ||||||||||
getRegionBuilder() { | ||||||||||
return regionBuilder; | ||||||||||
} | ||||||||||
|
||||||||||
/// Returns elementwise op kind e.g. `add` inferred from func_type attr. | ||||||||||
ElemwiseFn getElemwiseFnVal() { | ||||||||||
return getFuncType(); | ||||||||||
} | ||||||||||
Review comment on lines +646 to +649: Why isn't |
||||||||||
|
||||||||||
/// Infer dimensionality of the `iteration space` from the result type. | ||||||||||
/// Useful when other means are not possible e.g. in case of absence of | ||||||||||
/// user-provided indexing map. | ||||||||||
Review comment on lines +651 to +653: Doesn't this simply return the rank of the result? There isn't much "inference" involved 😅 As per your implementation, |
||||||||||
unsigned getResultRank(); | ||||||||||
|
||||||||||
/// Elementwise op does not have to explicitly specify iterator type | ||||||||||
/// as it is always 'parallel'. The number of 'parallel' loops is | ||||||||||
/// inferred from other means (e.g. result tensor type). | ||||||||||
Review comment on lines +656 to +658: This part is already included in the general Op description above:
Please only document what |
||||||||||
SmallVector<utils::IteratorType> getIteratorTypesArray(); | ||||||||||
|
||||||||||
/// The default indexing maps are N identity-maps. 'N' depends on the | ||||||||||
/// arity of the elementwise op. The default case is when all input | ||||||||||
/// output tensors are same rank and no transpose/broadcast is needed. | ||||||||||
Review comment on lines +661 to +663: It would be good to refine this; to me, "N identity-maps" could mean "identity-maps with N dims". Perhaps,
|
||||||||||
static SmallVector<AffineMap> | ||||||||||
getDefaultIndexingMaps(unsigned N, unsigned numDims, | ||||||||||
MLIRContext *context); | ||||||||||
|
||||||||||
/// Returns true if the user defined indexing maps are not equal to | ||||||||||
/// the default (identity) map. | ||||||||||
Review comment on lines +668 to +669: Hm, this comment seems inconsistent with what the function name implies (i.e., the function name implies that this is simply returning |
||||||||||
bool hasUserDefinedMaps(); | ||||||||||
|
||||||||||
/// destination passing style interface method. | ||||||||||
Review comment:
Suggested change
|
||||||||||
::mlir::MutableOperandRange getDpsInitsMutable() { | ||||||||||
return getOutputsMutable(); | ||||||||||
} | ||||||||||
|
||||||||||
// Generic methods. | ||||||||||
std::string getLibraryCallName() { | ||||||||||
return generateLibraryCallName(getOperation()); | ||||||||||
} | ||||||||||
}]; | ||||||||||
} | ||||||||||
|
||||||||||
//===----------------------------------------------------------------------===// | ||||||||||
// Op definition for MatmulOp | ||||||||||
//===----------------------------------------------------------------------===// | ||||||||||
|
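The op description in the diff above says that, when no `indexing_maps` is given, every operand gets an identity map and every iterator is `parallel`, with the number of dims taken from the result rank. A minimal standalone sketch of that default, assuming one map per input plus one for the output (which is how the three-map binary example reads). `ToyMap`, `identityMap`, `defaultIndexingMaps`, and `iteratorTypes` are hypothetical names used only for illustration; they are not the PR's `getDefaultIndexingMaps` or `getIteratorTypesArray`, and no MLIR API is used:

```cpp
#include <cassert>
#include <numeric>
#include <string>
#include <vector>

// Hypothetical stand-in for an affine map limited to projected
// permutations: results[i] is the iteration-space dim read by result i.
struct ToyMap {
  unsigned numDims;
  std::vector<unsigned> results;
};

// One identity map over `numDims` dims: (d0, ..., dk) -> (d0, ..., dk).
ToyMap identityMap(unsigned numDims) {
  ToyMap m{numDims, std::vector<unsigned>(numDims)};
  std::iota(m.results.begin(), m.results.end(), 0u);
  return m;
}

// `numMaps` copies of the identity map. Assumption: numMaps is the
// number of inputs plus one for the output.
std::vector<ToyMap> defaultIndexingMaps(unsigned numMaps, unsigned numDims) {
  assert(numMaps >= 2 && "expect at least one input and one output");
  return std::vector<ToyMap>(numMaps, identityMap(numDims));
}

// Every loop of an element-wise op is parallel; there is one loop per
// dimension of the result.
std::vector<std::string> iteratorTypes(unsigned resultRank) {
  return std::vector<std::string>(resultRank, "parallel");
}
```

For the binary `add` example with a rank-3 result, `defaultIndexingMaps(3, 3)` would yield three rank-3 identity maps and `iteratorTypes(3)` three `parallel` entries. This also separates the two numbers the reviewers found ambiguous: the count of maps and the number of dims per map.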
Review comment: This might be frowned upon, but I don't think we need two different sets of enums. We could just use a range of enums in `ElemwiseFn` below to denote unary/binary and ternary functions....
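The range-based classification suggested in this comment could look roughly like the sketch below. It mirrors the enum values from the diff as plain C++ `enum class` definitions rather than the TableGen-generated ones, and it assumes the unary, binary, and ternary cases stay contiguous in that order. The function body is an illustration, not the PR's actual `getNAryCategory` implementation:

```cpp
#include <cassert>
#include <cstdint>

// Plain C++ mirrors of the enums in the diff (values copied from it).
enum class ElemwiseNAryCategory : std::int32_t {
  Unary = 0,
  Binary = 1,
  Ternary = 2
};

enum class ElemwiseFn : std::int32_t {
  // Unary
  exp = 0, log = 1, abs = 2, ceil = 3, floor = 4, negf = 5,
  reciprocal = 6, round = 7, sqrt = 8, rsqrt = 9, square = 10,
  tanh = 11, erf = 12,
  // Binary
  add = 13, sub = 14, mul = 15, div = 16, div_unsigned = 17,
  max_signed = 18, min_signed = 19, max_unsigned = 20,
  min_unsigned = 21, powf = 22,
  // Ternary
  select = 23
};

// Classify a function by the value range it falls into. This relies on
// `erf` being the last unary case and `powf` the last binary case.
inline ElemwiseNAryCategory getNAryCategory(ElemwiseFn fn) {
  const auto v = static_cast<std::int32_t>(fn);
  assert(v >= 0 && v <= static_cast<std::int32_t>(ElemwiseFn::select) &&
         "unknown ElemwiseFn value");
  if (v <= static_cast<std::int32_t>(ElemwiseFn::erf))
    return ElemwiseNAryCategory::Unary;
  if (v <= static_cast<std::int32_t>(ElemwiseFn::powf))
    return ElemwiseNAryCategory::Binary;
  return ElemwiseNAryCategory::Ternary;
}
```

The trade-off is that the boundaries (`erf`, `powf`) must be updated whenever a case is appended to a category, which is the same ordering constraint the diff's comment already places on `ElemwiseFn`.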