Skip to content

Commit

Permalink
[mlir][linalg] Extend linalg elementwise
Browse files Browse the repository at this point in the history
Implements the Linalg elementwise named op, following the proposal and discussion in the RFC:
https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1
  • Loading branch information
javedabsar1 committed Feb 21, 2025
1 parent ff99af7 commit 4235fb9
Show file tree
Hide file tree
Showing 7 changed files with 721 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
}];
}

// Attribute wrapper around the `ElementwiseKind` enum (e.g. `add`). It is
// registered in the Linalg dialect under the mnemonic `elementwise_kind` and
// its value is written between angle brackets.
def ElementwiseKindAttr
    : EnumAttr<Linalg_Dialect, ElementwiseKind, "elementwise_kind"> {
  let assemblyFormat = "`<` $value `>`";
}

// Define the function attribute enums matching the OpDSL functions.
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";
Expand Down
59 changes: 59 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}

// Concatenates two I32EnumAttrCase lists into `result`. The cases of `a` are
// kept as they are; every case of `b` is re-created with the same symbol and
// with its integer value shifted up by `!size(a)`, so that the values coming
// from `b` do not collide with those of `a`. Only the symbol and the value of
// each `b` case are carried over to the new case.
// NOTE(review): the shift only guarantees disjoint values when every value in
// `a` is below `!size(a)` (e.g. numbered densely from 0) - confirm at callers.
class JoinTwoI32EnumAttrCaseList<list<I32EnumAttrCase> a,
                                 list<I32EnumAttrCase> b> {
  int aSize = !size(a);
  list<I32EnumAttrCase> result =
      !listconcat(a,
                  !foreach(entry, b,
                           I32EnumAttrCase<entry.symbol,
                                           !add(entry.value, aSize)>));
}

// Flattens a list of I32EnumAttrCase lists into the single list `result`.
// The sub-lists are folded left to right with `JoinTwoI32EnumAttrCaseList`,
// starting from an empty list, so each sub-list is appended with its enum
// values offset by the number of cases accumulated so far.
class ConcatI32EnumAtrCaseList<list<list<I32EnumAttrCase>> l> {
  list<I32EnumAttrCase> result =
      !foldl([]<I32EnumAttrCase>, l, joined, next,
             JoinTwoI32EnumAttrCaseList<joined, next>.result);
}

// Unified `enum class : i32` covering all element-wise op functions. Its
// cases are the unary, binary and ternary function cases, concatenated in
// that order with non-overlapping integer values.
def ElementwiseKind
    : I32EnumAttr<"ElementwiseKind", "",
                  ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
                                            BinaryFn.enumerants,
                                            TernaryFn.enumerants]>.result> {
  let cppNamespace = "::mlir::linalg";
  let genSpecializedAttr = 0;
}

// `enum class : i32` recording where the cases contributed by each individual
// enum (UnaryFn, BinaryFn, TernaryFn) end inside the unified enum
// `ElementwiseKind`. Each limit is a cumulative case count: the number of
// cases of that group plus the cases of all groups before it.
def ElementwiseCaseLimits : I32EnumAttr<"ElementwiseCaseLimits", "", []> {
  // Running totals of the number of cases per arity group.
  int last_unary = !size(UnaryFn.enumerants);
  int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
  int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));

  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::linalg";
  // The case list is set here, rather than through the template argument, so
  // that it can refer to the fields computed above.
  let enumerants = [
    I32EnumAttrCase<"LastUnary", last_unary>,
    I32EnumAttrCase<"LastBinary", last_binary>,
    I32EnumAttrCase<"LastTernary", last_ternary>
  ];
}

// `enum class : i32` classifying element-wise ops by arity. The integer value
// of each case equals the arity it names.
def ElementwiseArityGroup
    : I32EnumAttr<"ElementwiseArityGroup", "",
                  [I32EnumAttrCase<"Unary", 1>,
                   I32EnumAttrCase<"Binary", 2>,
                   I32EnumAttrCase<"Ternary", 3>]> {
  let cppNamespace = "::mlir::linalg";
  let genSpecializedAttr = 0;
}

def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
Expand Down
120 changes: 120 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Op definition for ElementwiseOp
//===----------------------------------------------------------------------===//
def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
                   AttrSizedOperandSegments]> {
  let summary = [{ Performs element-wise operation }];
  let description = [{
    The attribute `kind` describes the arithmetic operation to perform. The
    operation kind can be unary (e.g. exp), binary (e.g. add) or ternary
    (e.g. select).

    By default, all indexing maps are identities. In the case of default
    indexing map, all input and output shapes must match. The number of dims in
    each of the identity maps is equal to the rank of the output type.

    Affine-maps for operands and result are required to be provided by the user
    when a transpose and/or broadcast is needed on any operand. When a map is not
    provided, default identity maps are inferred for each operand.

    Iterator-types are always all `parallel`.
    Iterator-types are needed for constructing the underlying structured op.

    The number of dims of the iterator-types is inferred from the rank of
    the result type.

    Example:

    Defining a unary linalg.elementwise with default indexing-map:
    ```mlir
    %exp = linalg.elementwise
        kind=#linalg.elementwise_kind<exp>
        ins(%x : tensor<4x16x8xf32>)
        outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
    ```

    Defining a binary linalg.elementwise with user-defined indexing-map:
    ```mlir
    %add = linalg.elementwise
        kind=#linalg.elementwise_kind<add>
        indexing_maps = [#transpose, #broadcast, #identity]
        ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
        outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
    ```
  }];

  // `kind` selects the element-wise function. `indexing_maps` is optional and
  // defaults to an empty array; as described above, identity maps are then
  // inferred.
  let arguments = (ins
      Variadic<AnyType>:$inputs,
      Variadic<AnyShaped>:$outputs,
      ElementwiseKindAttr:$kind,
      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
    );

  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);
  // Default builders are skipped; the builder declared below is used instead.
  let skipDefaultBuilders = 1;

  let builders = [
    // Builds the op from inputs, outputs and attributes, passing no result
    // types (std::nullopt) and filling the region through `regionBuilder`.
    OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
          attributes, ElementwiseOp::getRegionBuilder());
      }]>
  ];

  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    /// Get the arity enum corresponding to the kind of op, e.g. if arg is
    /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
    static ElementwiseArityGroup getArityGroup(ElementwiseKind n);

    /// Both user-specified and default indexing map will always depend on
    /// the current Op instance.
    static bool hasDynamicIndexingMaps() { return true; }

    /// Implements the block region builder for the elementwiseOp. This is
    /// called by the 'fillStructuredOpRegion'.
    static void regionBuilder(ImplicitLocOpBuilder &b,
                              Block &block, ArrayRef<NamedAttribute> attrs);

    /// Returns `regionBuilder` wrapped as a `std::function`, in the form
    /// passed to `buildStructuredOp` by the builder of this op.
    static std::function<void(ImplicitLocOpBuilder &,
                              Block &, ArrayRef<NamedAttribute>)>
    getRegionBuilder() {
      return regionBuilder;
    }

    /// Returns rank of the result tensor/memref. Useful for knowing
    /// the dimensionality of the iteration space when other means
    /// are not possible e.g. absence of user-provided indexing map.
    unsigned getResultRank() {
      // The rank is read from the type of the first init (output) operand.
      Value output = getDpsInitOperand(0)->get();
      ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
      return shapedType.getRank();
    }

    /// Returns N 'parallel' iterator types where N is rank of result.
    SmallVector<utils::IteratorType> getIteratorTypesArray();

    /// The default indexing maps are identities.
    /// There will be N+1 such maps, where N is the arity of the Op.
    static SmallVector<AffineMap>
    getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
                           MLIRContext *context);

    /// Destination passing style interface method.
    ::mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputsMutable();
    }

    // Generic methods.
    std::string getLibraryCallName() {
      return generateLibraryCallName(getOperation());
    }
  }];
}

//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 4235fb9

Please sign in to comment.