From 4235fb93e871a2daae4686916816a4b44e811244 Mon Sep 17 00:00:00 2001
From: Javed Absar
Date: Thu, 20 Feb 2025 19:26:18 -0500
Subject: [PATCH] [mlir][linalg] Extend linalg elementwise

Implements the Linalg elementwise named op following the proposal and
discussions in the RFC:
https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1
---
 .../mlir/Dialect/Linalg/IR/LinalgBase.td      |   6 +
 .../mlir/Dialect/Linalg/IR/LinalgEnums.td     |  59 +++++
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 120 +++++++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 227 ++++++++++++++++++
 .../elementwise/generalize_named_ops.mlir     | 165 +++++++++++++
 .../Dialect/Linalg/elementwise/invalid.mlir   |  54 +++++
 .../Linalg/elementwise/round-trip.mlir        |  90 +++++++
 7 files changed, 721 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
 create mode 100644 mlir/test/Dialect/Linalg/elementwise/invalid.mlir
 create mode 100644 mlir/test/Dialect/Linalg/elementwise/round-trip.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d3..33601c5d6dad9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
   }];
 }
 
+// Define the attribute enums matching elementwise op kind (e.g., add).
+def ElementwiseKindAttr : EnumAttr<Linalg_Dialect, ElementwiseKind,
+                                   "elementwise_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 // Define the function attribute enums matching the OpDSL functions.
 def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
   let assemblyFormat = "`<` $value `>`";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d05..ce68afe471fe8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
 }
+
+// Join two I32EnumAttrCase lists. This joining ensures that the
+// 'int enum values' in the combined list do not overlap. It does this
+// by adding to each element of the second list the offset '!size(a)'.
+class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
+                                  list<I32EnumAttrCase> b> {
+  int aSize = !size(a);
+  list<I32EnumAttrCase> result =
+        !foldl(a, b, acc, var,
+               acc # [I32EnumAttrCase<var.symbol,
+                                      !add(var.value, aSize)>]);
+}
+
+// Flatten 'list of list of I32EnumAttrCase' to 'list of I32EnumAttrCase'.
+// The flattening (via call to 'join') ensures no overlap in enum values.
+class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
+  list<I32EnumAttrCase> result =
+        !foldl([]<I32EnumAttrCase>, l, acc, var,
+               JoinTwoI32EnumAttrCaseList<acc, var>.result);
+}
+
+// Define a unified `enum class : i32` for all element-wise op functions.
+def ElementwiseKind :
+  I32EnumAttr<"ElementwiseKind",
+              "",
+              ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
+                                        BinaryFn.enumerants,
+                                        TernaryFn.enumerants]>.result
+             > {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
+// Define an `enum class : i32` that marks where each individual enum class,
+// e.g. UnaryFn, BinaryFn, etc., ends in the unified enum class ElementwiseKind.
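+// A raw ElementwiseKind value `v` is classified as unary when `v < LastUnary`,
+// binary when `LastUnary <= v < LastBinary`, and ternary otherwise; the
+// group-local function is then recovered by subtracting the preceding limit
+// (see getArityGroupAndKind in LinalgOps.cpp below).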
+def ElementwiseCaseLimits : I32EnumAttr<"ElementwiseCaseLimits", "", []> {
+  int last_unary = !size(UnaryFn.enumerants);
+  int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
+  int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));
+
+  let enumerants = [
+    I32EnumAttrCase<"LastUnary", last_unary>,
+    I32EnumAttrCase<"LastBinary", last_binary>,
+    I32EnumAttrCase<"LastTernary", last_ternary>];
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
+// Define an `enum class : i32` to categorise elementwise ops by arity.
+def ElementwiseArityGroup : I32EnumAttr<"ElementwiseArityGroup", "", [
+  I32EnumAttrCase<"Unary", 1>,
+  I32EnumAttrCase<"Binary", 2>,
+  I32EnumAttrCase<"Ternary", 3>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
 def TypeFn : I32EnumAttr<"TypeFn", "", [
   I32EnumAttrCase<"cast_signed", 0>,
   I32EnumAttrCase<"cast_unsigned", 1>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index a5725d6f1507e..ce6e9e7bb28c4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for ElementwiseOp
+//===----------------------------------------------------------------------===//
+def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
+                      AttrSizedOperandSegments]> {
+  let summary = [{ Performs element-wise operation }];
+  let description = [{
+    The attribute `kind` describes the arithmetic operation to perform. The
+    operation kind can be unary (e.g. exp), binary (e.g. add) or ternary
+    (e.g. select).
+
+    By default, all indexing maps are identities. In the identity-map case,
+    all input and output shapes must match. The number of dims in each of
+    the identity maps is equal to the rank of the output type.
+
+    Affine-maps for operands and result must be provided by the user when
+    a transpose and/or broadcast is needed on any operand. When a map is not
+    provided, default identity maps are inferred for each operand.
+
+    Iterator-types are always all `parallel`.
+    Iterator-types are needed for constructing the underlying structured op.
+
+    The number of dims of the iterator-types is inferred from the rank of
+    the result type.
+
+    Example:
+
+    Defining a unary linalg.elementwise with default indexing-map:
+    ```mlir
+    %exp = linalg.elementwise
+           kind=#linalg.elementwise_kind<exp>
+           ins(%x : tensor<4x16x8xf32>)
+           outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
+    ```
+
+    Defining a binary linalg.elementwise with user-defined indexing-map:
+    ```mlir
+    %add = linalg.elementwise
+           kind=#linalg.elementwise_kind<add>
+           indexing_maps = [#transpose, #broadcast, #identity]
+           ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
+           outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShapedType>:$outputs,
+    ElementwiseKindAttr:$kind,
+    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+  );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, ElementwiseOp::getRegionBuilder());
+      }]>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    /// Get the arity enum corresponding to the kind of op, e.g. if arg is
+    /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
+    static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
+
+    /// Both the user-specified and default indexing maps always depend on
+    /// the current Op instance.
+    static bool hasDynamicIndexingMaps() { return true; }
+
+    /// Implements the block region builder for the ElementwiseOp. This is
+    /// called by 'fillStructuredOpRegion'.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    /// Returns rank of the result tensor/memref. Useful for knowing
+    /// the dimensionality of the iteration space when other means
+    /// are not possible, e.g. in the absence of a user-provided indexing map.
+    unsigned getResultRank() {
+      Value output = getDpsInitOperand(0)->get();
+      ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
+      return shapedType.getRank();
+    }
+
+    /// Returns N 'parallel' iterator types, where N is the rank of the result.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// The default indexing maps are identities.
+    /// There will be N+1 such maps, where N is the arity of the Op.
+    static SmallVector<AffineMap>
+    getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                           MLIRContext *context);
+
+    /// Destination-passing-style interface method.
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+
+    // Generic methods.
+    std::string getLibraryCallName() {
+      return generateLibraryCallName(getOperation());
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Op definition for MatmulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 42ea0e1197ef1..161c334c4c985 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4058,6 +4058,233 @@ Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ElementwiseOp
+//===----------------------------------------------------------------------===//
+//
+namespace {
+struct ArityGroupAndKind {
+  // The enum class {Unary, Binary, Ternary, ..}.
+  ElementwiseArityGroup arityGroup;
+
+  // The kind (e.g. `exp` or `add`) belonging to the arity group.
+  union Kind {
+    UnaryFn unaryFn;
+    BinaryFn binaryFn;
+    TernaryFn ternaryFn;
+  } kind;
+};
+
+unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
+  return static_cast<unsigned>(arityGroup);
+}
+} // namespace
+
+static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
+  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
+  constexpr int lastBinary =
+      static_cast<int>(ElementwiseCaseLimits::LastBinary);
+  constexpr int lastTernary =
+      static_cast<int>(ElementwiseCaseLimits::LastTernary);
+
+  int val = static_cast<int>(kind);
+  ArityGroupAndKind result;
+
+  if (val < lastUnary) {
+    result.arityGroup = ElementwiseArityGroup::Unary;
+    result.kind.unaryFn = static_cast<UnaryFn>(val);
+    return result;
+  }
+  if (val < lastBinary) {
+    result.arityGroup = ElementwiseArityGroup::Binary;
+    result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
+    return result;
+  }
+  if (val >= lastTernary) {
+    llvm_unreachable("unhandled ElementwiseFn");
+  }
+  result.arityGroup = ElementwiseArityGroup::Ternary;
+  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
+  return result;
+}
+
+SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
+  auto rank = getResultRank();
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+}
+
+SmallVector<AffineMap>
+ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                                      MLIRContext *context) {
+  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
+  return SmallVector<AffineMap>(numMaps, map);
+}
+
+ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Expect e.g. `kind = #linalg.elementwise_kind<add>`.
+  Attribute attr;
+  mlir::linalg::ElementwiseKind elemwiseKindVal;
+  if (parser.parseKeyword("kind") || parser.parseEqual())
+    return failure();
+
+  if (succeeded(parser.parseAttribute(attr))) {
+    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
+    if (!elemwiseKindAttr)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected ElementwiseKind attribute");
+    elemwiseKindVal = elemwiseKindAttr.getValue();
+  } else {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected operation 'kind' attribute");
+  }
+  result.addAttribute(
+      "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
+
+  // Parse optional `indexing_maps`.
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+    if (parser.parseLSquare())
+      return failure();
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      indexingMapsAttr.push_back(mapAttr);
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // At this stage of parsing, the only way to infer the number of region
+  // args is through the op kind, as the input/output tensors are not parsed
+  // yet.
+  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
+  int numRegionArgs =
+      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
+  if (parseNamedStructuredOp(parser, result, numRegionArgs,
+                             ElementwiseOp::getRegionBuilder())) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "unable to parse elemwise op");
+  }
+
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    // We need to infer the numDims of the indexing maps from the output
+    // type, which is already parsed by now.
+    auto resultType = result.operands[result.operands.size() - 1].getType();
+    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+    if (!shapedType)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "return type needs to be shaped type");
+    auto numDims = shapedType.getRank();
+    indexingMapsAttr = llvm::map_to_vector(
+        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
+                                              parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+  return success();
+}
+
+void ElementwiseOp::print(OpAsmPrinter &p) {
+  p << " kind=";
+  p.printAttribute(getKindAttr());
+  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
+                                           "indexing_maps"};
+  unsigned arity =
+      getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
+  unsigned numDims = getResultRank();
+
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
+                                            getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
+}
+
+LogicalResult ElementwiseOp::verify() {
+  // All necessary checks are done either by
+  // - EnumAttr (e.g. unknown operation kind)
+  // - verifyStructuredOpInterface (incorrect map, sizes).
+  return success();
+}
+
+/// Implements the block region builder for the ElementwiseOp. This is
+/// called by 'fillStructuredOpRegion'.
+void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  ElementwiseKind elemwiseKind;
+  for (auto attr : attrs) {
+    if (attr.getName() == b.getStringAttr("kind")) {
+      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
+      assert(kindAttr && "op kind attribute incorrectly set");
+      elemwiseKind = kindAttr.getValue();
+      break;
+    }
+  }
+
+  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
+  auto arityGroup = groupAndKind.arityGroup;
+  auto kind = groupAndKind.kind;
+  unsigned numBlockArgs = getArityGroupAsUInt(arityGroup) + 1 /*output*/;
+  assert(block.getNumArguments() == numBlockArgs &&
+         "Elementwise regionBuilder number of block args mismatch");
+
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+  Value result;
+
+  if (arityGroup == ElementwiseArityGroup::Unary) {
+    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+
+  } else if (arityGroup == ElementwiseArityGroup::Binary) {
+    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
+                                  block.getArgument(1));
+
+  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
+    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
+                                   block.getArgument(1), block.getArgument(2));
+
+  } else
+    assert(false && "found unhandled category in elemwise");
+
+  yields.push_back(result);
+  helper.yieldOutputs(yields);
+}
+
+LogicalResult ElementwiseOp::fold(FoldAdaptor,
+                                  SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void ElementwiseOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ElementwiseOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
new file mode 100644
index 0000000000000..94a46d97e6e86
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops -split-input-file | FileCheck %s
+// CHECK: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//
+// CHECK: @unary_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
+// CHECK: %[[EXP:.+]] = math.exp %[[A_ARG]] : f32
+// CHECK: linalg.yield %[[EXP]] : f32
+//
+func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+       kind=#linalg.elementwise_kind<exp>
+       ins(%A : tensor<8x16x32xf32>)
+       outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @unary_transpose_broadcast_tanh(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
+// CHECK: %[[TANH:.+]] = math.tanh %[[A_ARG]] : f32
+// CHECK: linalg.yield %[[TANH]] : f32
+//
+func.func @unary_transpose_broadcast_tanh(%A : tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+       kind=#linalg.elementwise_kind<tanh>
+       indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+       ins(%A : tensor<32x16xf32>)
+       outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
+// -----
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @binary_div_on_memrefs(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[DIV:.+]] = arith.divf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[DIV]] : f32
+//
+func.func @binary_div_on_memrefs(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+  linalg.elementwise
+       kind=#linalg.elementwise_kind<div>
+       ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>)
+       outs(%C: memref<16x8xf32>)
+  return
+}
+// -----
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @binary_mul_on_tensors(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[MUL]] : f32
+//
+func.func @binary_mul_on_tensors(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %r = linalg.elementwise
+       kind=#linalg.elementwise_kind<mul>
+       ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
+       outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK: @binary_transpose_a(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[SUB:.+]] = arith.subf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[SUB]] : f32
+//
+func.func @binary_transpose_a(%A : tensor<8x16xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %r = linalg.elementwise
+       kind=#linalg.elementwise_kind<sub>
+       indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+                        affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>]
+       ins(%A, %B: tensor<8x16xf32>, tensor<16x8xf32>)
+       outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[BROADCAST:.+]] = affine_map<(d0, d1) -> (d0)>
+//
+// CHECK: @binary_transpose_a_broadcast_b(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[BROADCAST]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[ADD:.+]] = arith.addf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+//
+func.func @binary_transpose_a_broadcast_b(%A : tensor<8x16xf32>, %B: tensor<16xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %r = linalg.elementwise
+       kind=#linalg.elementwise_kind<add>
+       indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+                        affine_map<(d0, d1) -> (d0)>,
+                        affine_map<(d0, d1) -> (d0, d1)>]
+       ins(%A, %B: tensor<8x16xf32>, tensor<16xf32>)
+       outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @ternary(%[[A:.+]]: tensor<32x16xi1>, %[[B:.+]]: tensor<8x16x32xf32>, %[[C:.+]]: tensor<8x16x32xf32>, %[[D:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+//
+// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]]
+// CHECK-SAME: outs(%[[D]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i1, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32, %[[D_ARG:.+]]: f32)
+// CHECK: %[[SELECTED:.+]] = arith.select %[[A_ARG]], %[[B_ARG]], %[[C_ARG]] : f32
+// CHECK: linalg.yield %[[SELECTED]] : f32
+//
+func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8x16x32xf32>, %D : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+       kind=#linalg.elementwise_kind<select>