Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][linalg] Extend elementwise #124661

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
}];
}

// Define the attribute enums matching elementwise op kind (e.g., add).
def ElementwiseKindAttr : EnumAttr<Linalg_Dialect,
ElementwiseKind, "elementwise_kind"> {
let assemblyFormat = "`<` $value `>`";
}

// Define the function attribute enums matching the OpDSL functions.
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";
Expand Down
59 changes: 59 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

// Join two I32EnumAttrCase lists. This joining takes care that the
// 'int enum values' in the combined list do not overlap. It does this
// by adding to each element of second list the offset '!size(a)'.
class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud, perhaps it would be easier to use bit patterns instead of joining enum lists. We won't have more than 20 operations per Category, so:

  • Unary: (op | (0xFF << 1))
  • Binary: (op | (0xFF << 2))
  • Ternary: (op | (0xFF << 3))

And set the enums above like:

  • I32EnumAttrCase<"log", (1 << 1)>
  • I32EnumAttrCase<"sub", (1 << 2)>
  • I32EnumAttrCase<"select", (1 << 3)>

etc?

Then you don't need all the complex sequential logic in the parser/verifiers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need anymore.

Copy link
Contributor

@rolfmorel rolfmorel Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the alternative join-based approach, @javedabsar1 - It's quite impressive what can done with TableGen! (For better or worse, TableGen is a programming language on its own.)

I would still think an approach like @rengolin's would lead to simpler C++. The scheme I have in mind is to just shift the arity 30 bits, e.g. I32EnumAttrCase<"abs", 2> becomes I32EnumAttrCase<"abs", 2 + (1 << 30)>,
I32EnumAttrCase<"div", 3 + (2 << 30)>
I32EnumAttrCase<"select", 0 + (3 << 30)>. This way the arity can be retrieved by just shifting right 30 bits (e.g. derivedEnumVal >> 30) and to obtain the original op code you just do derivedEnumVal & ((1 << 30) - 1).

Three nested !lfolds should now suffice to derive all the derived enum cases and ElementwiseFnLimits could go and NAryCategoryAndFn could go or be simplified.

What do you think? (If you could just state the benefits of your approach, that would also be fine OFC.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me make a pragmatic case for C++ ...

I've not seen much (any?) sophisticated TableGen in dialect definition. C++ might simply be more familiar to folks. I've already revealed my own personal preference 😅

list<I32EnumAttrCase> b> {
int aSize = !size(a);
list<I32EnumAttrCase> result =
!foldl(a, b, acc, var,
acc # [I32EnumAttrCase<var.symbol,
!add(var.value, aSize)
>]);
}

// Flatten 'list of list of I32EnumAttrCase' to 'list of I32EnumAttrCase'.
// The flattening (via call to 'join') ensures no overlap in enum values.
class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
list<I32EnumAttrCase> result =
!foldl([]<I32EnumAttrCase>, l, acc, var,
JoinTwoI32EnumAttrCaseList<acc, var>.result);
}

// Define a unified `enum class : i32` for all element-wise op functions.
def ElementwiseKind :
I32EnumAttr<"ElementwiseKind",
"",
ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
BinaryFn.enumerants,
TernaryFn.enumerants]>.result
> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

// Define an `enum class : i32` that marks where each individual enum class
// e.g. UnaryFn, BinaryFn, etc. end in the unified enum class ElementwiseKind.
def ElementwiseCaseLimits : I32EnumAttr<"ElementwiseCaseLimits", "", []> {
int last_unary = !size(UnaryFn.enumerants);
int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));

let enumerants = [
I32EnumAttrCase<"LastUnary", last_unary>,
I32EnumAttrCase<"LastBinary", last_binary>,
I32EnumAttrCase<"LastTernary", last_ternary>];
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

// Define an `enum class : i32` to categorise arity elementwise ops.
def ElementwiseArityGroup : I32EnumAttr<"ElementwiseArityGroup", "", [
I32EnumAttrCase<"Unary", 1>,
I32EnumAttrCase<"Binary", 2>,
I32EnumAttrCase<"Ternary", 3>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
Expand Down
120 changes: 120 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Op definition for ElementwiseOp
//===----------------------------------------------------------------------===//
def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
AttrSizedOperandSegments]> {
let summary = [{ Performs element-wise operation }];
let description = [{
The attribute `kind` describes arithmetic operation to perform. The
operation kind can be unary (e.g. max), binary (e.g. add) or ternary
(e.g. select).

By default, all indexing maps are identities. In the case of default
indexing map, all input and output shapes must match. The number of dims in
each of the identity maps is equal to the rank of the output type.

Affine-maps for operands and result are required to be provided by the user
when a transpose and/or broadcast is needed on any operand. When a map is not
provided, default identity maps are inferred for each operand.

Iterator-types are always all `parallel`.
Iterator-types are needed for constructing the underlying structured op.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this is relevant here. This is somehow implying the op stores iterator types on the operation. This is confusing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed on this.


The number of dims of the iterator-types are inferred from the rank of
the result type.

Example:

Defining a unary linalg.elemwise with default indexing-map:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Old op's name, needs updating

```mlir
%exp = linalg.elemwise
kind=#linalg.elemwise_kind<exp>
ins(%x : tensor<4x16x8xf32>)
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
```

Defining a binary linalg.elemwise with user-defined indexing-map:
```mlir
%add = linalg.elemwise
kind=#linalg.elemwise_kind<add>
indexing_maps = [#transpose, #broadcast, #identity]
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
```
}];

let arguments = (ins
Variadic<AnyType>:$inputs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason to accept AnyType on the input and not AnyShaped?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks. changed to AnyShaped.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a comment about this below: the primary reason to allow AnyType is so that a scalar operand can be broadcast to the full output shape.

Variadic<AnyShaped>:$outputs,
ElementwiseKindAttr:$kind,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this should be DefaultValuedOptionalAttr. "{}" can be invalid in many cases. Instead, we should just have a builder for having a derived default value of the attribute.

);

let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
let skipDefaultBuilders = 1;

let builders = [
OpBuilder<
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, ElementwiseOp::getRegionBuilder());
}]>
];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{
/// Get the arity enum corresponding to the kind of op, e.g. if arg is
/// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
static ElementwiseArityGroup getArityGroup(ElementwiseKind n);

/// Both user-specified and default indexing map will always depend on
/// the current Op instance.
static bool hasDynamicIndexingMaps() { return true; }

/// Implements the block region builder for the elementwiseOp. This is
/// called by the 'fillStructuredOpRegion'.
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);

static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}

/// Returns rank of the result tensor/memref. Useful for knowing
/// the dimensionality of the iteration space when others means
/// are not possible e.g. absence of user-provided indexing map.
unsigned getResultRank() {
Value output = getDpsInitOperand(0)->get();
ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
return shapedType.getRank();
}

/// Returns N 'parallel' iterator types where N is rank of result.
SmallVector<utils::IteratorType> getIteratorTypesArray();

/// The default indexing maps are identities.
/// There will be N+1 such maps, where N is the arity of the Op.
static SmallVector<AffineMap>
getDefaultIndexingMaps(unsigned NumMaps, unsigned numDims,
MLIRContext *context);

/// Destination passing style interface method.
::mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}

// Generic methods.
std::string getLibraryCallName() {
return generateLibraryCallName(getOperation());
}
}];
}

//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading