[mlir][linalg] Extend Linalg elemwise named ops semantics #122753

Closed
wants to merge 4 commits into from
Changes from 2 commits
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,6 +61,11 @@ def Linalg_Dialect : Dialect {
}];
}

// Define the enum-type Elemwise func attribute.
def ElemwiseFnAttr : EnumAttr<Linalg_Dialect, ElemwiseFn, "elemwise_fn"> {
let assemblyFormat = "`<` $value `>`";
}

// Define the function attribute enums matching the OpDSL functions.
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";
51 changes: 51 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -15,6 +15,57 @@

include "mlir/IR/EnumAttr.td"

// Define an `enum class : i32` to categorise element-wise op.
def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [
Contributor

This might be frowned upon, but I don't think we need two different sets of enums. We could just use a range of enums in ElemwiseFn below to denote unary, binary, and ternary functions.

I32EnumAttrCase<"Unary", 0>,
Contributor

Do we need separate enums here? Can we just use the enum below to classify what is a unary, binary, or ternary function? (I think this is a repeat comment from above)

I32EnumAttrCase<"Binary", 1>,
I32EnumAttrCase<"Ternary", 2>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

// Define a unified `enum class : i32` for all element-wise options.
// Note: The order of individual fn (e.g. 'exp', 'log') within each
// category (Unary, Binary etc.) must match the ordering of same fn
// defined in UnaryFn, BinaryFn. This is to enable correct mapping
// from this unified enum class to different category enums.
def ElemwiseFn : I32EnumAttr<"ElemwiseFn", "", [
// Unary
I32EnumAttrCase<"exp", 0>,
I32EnumAttrCase<"log", 1>,
I32EnumAttrCase<"abs", 2>,
I32EnumAttrCase<"ceil", 3>,
I32EnumAttrCase<"floor", 4>,
I32EnumAttrCase<"negf", 5>,
I32EnumAttrCase<"reciprocal", 6>,
I32EnumAttrCase<"round", 7>,
I32EnumAttrCase<"sqrt", 8>,
I32EnumAttrCase<"rsqrt", 9>,
I32EnumAttrCase<"square", 10>,
I32EnumAttrCase<"tanh", 11>,
I32EnumAttrCase<"erf", 12>,

// Binary

Contributor

nit: unnecessary extra new line

I32EnumAttrCase<"add", 13>,
I32EnumAttrCase<"sub", 14>,
I32EnumAttrCase<"mul", 15>,
I32EnumAttrCase<"div", 16>,
I32EnumAttrCase<"div_unsigned", 17>,
I32EnumAttrCase<"max_signed", 18>,
I32EnumAttrCase<"min_signed", 19>,
I32EnumAttrCase<"max_unsigned", 20>,
I32EnumAttrCase<"min_unsigned", 21>,
I32EnumAttrCase<"powf", 22>,

// Ternary
I32EnumAttrCase<"select", 23>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
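
The reviewers' suggestion above (derive the category from value ranges in the single `ElemwiseFn` enum instead of keeping a second enum) could look roughly like the following. This is a standalone C++ sketch, not code from the PR: the enum only mirrors the boundary values listed in `ElemwiseFn` above, and `Arity`, `getArity`, and `offsetInCategory` are hypothetical names introduced for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the ElemwiseFn cases above; only category boundaries are spelled out.
enum class ElemwiseFn : int32_t {
  exp = 0,     // first unary fn
  erf = 12,    // last unary fn
  add = 13,    // first binary fn
  powf = 22,   // last binary fn
  select = 23  // only ternary fn
};

// Hypothetical stand-in for a separate ElemwiseNAryCategory enum.
enum class Arity { Unary = 1, Binary = 2, Ternary = 3 };

// Classify a fn purely by where its value falls in the unified enum.
inline Arity getArity(ElemwiseFn fn) {
  auto v = static_cast<int32_t>(fn);
  assert(v >= 0 && v <= static_cast<int32_t>(ElemwiseFn::select));
  if (v <= static_cast<int32_t>(ElemwiseFn::erf))
    return Arity::Unary;
  if (v <= static_cast<int32_t>(ElemwiseFn::powf))
    return Arity::Binary;
  return Arity::Ternary;
}

// Position of a fn within its own category. Per the ordering note above, this
// is intended to equal the value of the same fn in the per-category enum.
inline int32_t offsetInCategory(ElemwiseFn fn) {
  auto v = static_cast<int32_t>(fn);
  switch (getArity(fn)) {
  case Arity::Unary:
    return v - static_cast<int32_t>(ElemwiseFn::exp);
  case Arity::Binary:
    return v - static_cast<int32_t>(ElemwiseFn::add);
  case Arity::Ternary:
    return v - static_cast<int32_t>(ElemwiseFn::select);
  }
  return -1; // not reached for valid enum values
}
```

With this layout, adding a new unary fn means shifting every binary and ternary value, which is the trade-off of range-based classification compared with a separate category enum.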

// Define the function attribute enums matching the OpDSL functions.
def UnaryFn : I32EnumAttr<"UnaryFn", "", [
I32EnumAttrCase<"exp", 0>,
130 changes: 130 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -551,6 +551,136 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
Contributor

[nit] Let's stick with the existing convention.

Suggested change
- // Op definition for ElemwiseOp - with user-defined maps, computation type etc.
+ // Op definition for ElemwiseOp

//===----------------------------------------------------------------------===//

def ElemwiseOp : LinalgStructuredBase_Op<"elemwise", [
Contributor

Nit: I'd rather just use ElementWiseOp and element_wise. But totally my preference.

Contributor

+1 for full spelling, feels more intuitive
nit to nit: I'd go with a single word version ElementwiseOp and elementwise, it's also more consistent with spelling across repo

Contributor Author

> Nit: I'd rather just use ElementWiseOp and element_wise. But totally my preference.

I agree too. I just wasn't sure what you folks would prefer (longer, more precise IR, or shorter IR that leaves some guessing).

AttrSizedOperandSegments]> {
let summary = [{ Performs element-wise operation }];
let description = [{
Linalg op form which performs element-wise computation. The attribute
`func_type` describes the operation type (e.g. add, exp). The func_type
can be any valid unary, binary, or ternary operation.
Contributor

How are we defining "any valid"? I would either list all or skip this altogether.


Affine-maps for operands and result may be provided by the user. When
Contributor

> Affine-maps for operands and result may be provided by the user.

IIUC, maps are required when transposing or broadcasting?

a user-defined indexing_map is not provided, identity map is inferred
for all operands. The default indexing maps are N identity-maps. ‘N’
depends on the arity of the elementwise op. The number of dims is
Comment on lines +568 to +569
Contributor

Isn't N simply the rank of the output? And doesn't "arity" mean the number of function args? Is that relevant here?

inferred from rank of the output type. In the case of default indexing
map, the input and output shapes must all match. Affine-map for operands
Comment on lines +570 to +571
Contributor

Suggested change
- inferred from rank of the output type. In the case of default indexing
- map, the input and output shapes must all match. Affine-map for operands
+ inferred from rank of the output type. In the case of default indexing
+ maps, the input and output shapes must match. Affine-maps for operands

and result must be only projected permutations with no zero constants.
Contributor

Suggested change
- and result must be only projected permutations with no zero constants.
+ and result must be projected permutations with no zero constants.


For element-wise iterator-type is always inferred as all ‘parallel’.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- For element-wise iterator-type is always inferred as all ‘parallel’.
+ For element-wise Op iterator-types are always inferred as all ‘parallel’.

or

Suggested change
- For element-wise iterator-type is always inferred as all ‘parallel’.
+ Iterator-types are always inferred as all ‘parallel’.

Iterator-type is needed for constructing this underlying structured op.
The number of dims of the iterator-type is inferred from the rank of
the result type.

Example:
Defining a unary linalg.elemwise with default indexing-map:

```mlir
%exp = linalg.elemwise
func_type=#linalg.elemwise_fn<exp>
ins(%x : tensor<4x16x8xf32>)
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
```

Defining a binary linalg.elemwise with user-defined indexing-map:

```mlir
%add = linalg.elemwise
func_type=#linalg.elemwise_fn<add>
indexing_maps = [#transpose, #broadcast, #identity]
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
```
}];
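
The binary example in the description above uses `#transpose`, `#broadcast`, and `#identity` without showing their definitions. Assuming they are `affine_map<(d0, d1, d2) -> (d0, d2, d1)>`, `affine_map<(d0, d1, d2) -> (d0, d2)>`, and `affine_map<(d0, d1, d2) -> (d0, d1, d2)>` respectively (an assumption, the PR text does not spell them out), the operand shapes can be checked with a small standalone C++ model. `DimMap`, `isProjectedPermutation`, and `operandShape` here are illustrative helpers, not MLIR APIs, and the model cannot express constant results at all.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// A projected permutation, modelled as the list of iteration-space dims that
// appear in the map's results, e.g. (d0, d1, d2) -> (d0, d2) is {0, 2}.
using DimMap = std::vector<unsigned>;

// True if every result is a distinct, in-range dim (no repeats).
inline bool isProjectedPermutation(const DimMap &map, unsigned numDims) {
  std::vector<bool> seen(numDims, false);
  for (unsigned d : map) {
    if (d >= numDims || seen[d])
      return false;
    seen[d] = true;
  }
  return true;
}

// Shape an operand must have for a given iteration-space shape and map.
inline std::vector<int64_t> operandShape(const std::vector<int64_t> &iterShape,
                                         const DimMap &map) {
  assert(isProjectedPermutation(map, static_cast<unsigned>(iterShape.size())));
  std::vector<int64_t> shape;
  for (unsigned d : map)
    shape.push_back(iterShape[d]);
  return shape;
}
```

Taking the iteration space from the output type `tensor<4x8x16xf32>` (identity map), `operandShape({4, 8, 16}, {0, 2, 1})` yields `4x16x8` and `operandShape({4, 8, 16}, {0, 2})` yields `4x16`, which matches the types of `%exp` and `%arg1` in the example under the assumed maps.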

let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
ElemwiseFnAttr:$func_type,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
);

let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
Contributor

Do we need a region on the op?

Contributor

Just to cross-pollinate the Linalg discussions.
It looks like regions are required by Linalg's interface. Thus, all ops currently have and should have a region.


let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildElemwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, ElemwiseOp::getRegionBuilder());
}]>
];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{

/// Get the nary category enum, e.g. `ElemwiseNAryCategory::Unary`,
/// corresponding to the given fn, e.g. `ElemwiseFn::exp`
static ElemwiseNAryCategory getNAryCategory(ElemwiseFn fn);

/// Elementwise is always `dynamic indexing maps` i.e. `user specified`
/// or `default`. Default is identity-maps.
Comment on lines +631 to +632
Contributor

I've checked the comment for hasDynamicIndexingMaps and this would make a bit more sense to me:

Suggested change
- /// Elementwise is always `dynamic indexing maps` i.e. `user specified`
- /// or `default`. Default is identity-maps.
+ /// Both user-specified and default indexing map will always depend on the current Op instance.

static bool hasDynamicIndexingMaps() { return true; }

/// Implements the block region builder for the eemwiseOp. This is called
Contributor

Suggested change
- /// Implements the block region builder for the eemwiseOp. This is called
+ /// Implements the block region builder for the elemwiseOp. This is called

/// by the 'fillStructuredOpRegion'.
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);

static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}

/// Returns elementwise op kind e.g. `add` inferred from func_type attr.
ElemwiseFn getElemwiseFnVal() {
return getFuncType();
}
Comment on lines +646 to +649
Contributor

Why isn't getFuncType sufficient?


/// Infer dimensionality of the `iteration space` from the result type.
/// Useful when others means are not possible e.g. in case of absence of
/// user-provided indexing map.
Comment on lines +651 to +653
Contributor

Doesn't this simply return the rank of the result? There isn't much "inference" involved 😅 As per your implementation, return shapedType.getRank();.

unsigned getResultRank();

/// Elementwise op does not have to explicitly specify iterator type
/// as it is always 'parallel'. The number of 'parallel' loops is
/// inferred from other means (e.g. result tensor type).
Comment on lines +656 to +658
Contributor

This part is already included in the general Op description above:

> /// Elementwise op does not have to explicitly specify iterator type
> /// as it is always 'parallel'

Please only document what getIteratorTypesArray does.

SmallVector<utils::IteratorType> getIteratorTypesArray();

/// The default indexing maps are N identity-maps. 'N' depends on the
/// arity of the elementwise op. The default case is when all input
/// output tensors are same rank and no transpose/broadcast is needed.
Comment on lines +661 to +663
Contributor

It would be good to refine this, to me "N identity-maps" could mean "identity-maps with N dims". Perhaps,

The default indexing maps are identities. There will be N such maps, where N is the arity of the Op.

static SmallVector<AffineMap>
getDefaultIndexingMaps(unsigned N, unsigned numDims,
MLIRContext *context);

/// Returns true if the user defined indexing maps are not equal to
/// the default (identity) map.
Comment on lines +668 to +669
Contributor

Hm, this comment seems inconsistent with what the function name implies (i.e., the function name implies that this simply returns true if the user specified the indexing maps). What happens if the user does specify indexing maps and these maps are identities?

bool hasUserDefinedMaps();

/// destination passing style interface method.
Contributor

Suggested change
- /// destination passing style interface method.
+ /// Destination passing style interface method.

::mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}

// Generic methods.
std::string getLibraryCallName() {
return generateLibraryCallName(getOperation());
}
}];
}
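
Two of the review threads above concern the default indexing maps: the "N identity-maps" wording, and what `hasUserDefinedMaps` should return when the user supplies identity maps explicitly. The standalone C++ sketch below models both points. It is not the PR's implementation: `DimMap`, `identityMap`, `defaultIndexingMaps`, and `differsFromDefault` are hypothetical helpers, and the number of maps is left as a parameter because the diff does not show whether N counts the output operand.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// An indexing map modelled as the list of iteration-space dims in its results.
using DimMap = std::vector<unsigned>;

// Identity over `numDims` dims: (d0, ..., dn-1) -> (d0, ..., dn-1).
inline DimMap identityMap(unsigned numDims) {
  DimMap map(numDims);
  std::iota(map.begin(), map.end(), 0u);
  return map;
}

// `numMaps` identical identity maps, each with `numDims` dims. The two
// parameters are independent: one counts maps, the other is the rank.
inline std::vector<DimMap> defaultIndexingMaps(unsigned numMaps,
                                               unsigned numDims) {
  assert(numMaps > 0 && "an op needs at least one operand map");
  return std::vector<DimMap>(numMaps, identityMap(numDims));
}

// One possible reading of the documented behaviour: true only if at least one
// map is not the identity, regardless of who supplied the maps.
inline bool differsFromDefault(const std::vector<DimMap> &maps,
                               unsigned numDims) {
  const DimMap identity = identityMap(numDims);
  for (const DimMap &m : maps)
    if (m != identity)
      return true;
  return false;
}
```

Under this reading, explicitly passing identity maps is indistinguishable from passing none, which is the mismatch between the function name and its doc comment that the reviewer points out.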

//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//