diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td index 73f984dc072d3..115eaebc6aff5 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -61,6 +61,11 @@ def Linalg_Dialect : Dialect { }]; } +// Define the enum-type Elemwise func attribute. +def ElemwiseFnAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + // Define the function attribute enums matching the OpDSL functions. def UnaryFnAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td index e615876a95d05..5135e9cd4386e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td @@ -15,6 +15,57 @@ include "mlir/IR/EnumAttr.td" +// Define an `enum class : i32` to categorise element-wise op. +def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [ + I32EnumAttrCase<"Unary", 0>, + I32EnumAttrCase<"Binary", 1>, + I32EnumAttrCase<"Ternary", 2> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::linalg"; +} + +// Define a unified `enum class : i32` for all element-wise options. +// Note: The order of individual fn (e.g. 'exp', 'log') within each +// category (Unary, Binary etc.) must match the ordering of same fn +// defined in UnaryFn, BinaryFn. This is to enable correct mapping +// from this unified enum class to different category enums. +def ElemwiseFn : I32EnumAttr<"ElemwiseFn", "", [ + // Unary + I32EnumAttrCase<"exp", 0>, + I32EnumAttrCase<"log", 1>, + I32EnumAttrCase<"abs", 2>, + I32EnumAttrCase<"ceil", 3>, + I32EnumAttrCase<"floor", 4>, + I32EnumAttrCase<"negf", 5>, + I32EnumAttrCase<"reciprocal", 6>, + I32EnumAttrCase<"round", 7>, + I32EnumAttrCase<"sqrt", 8>, + I32EnumAttrCase<"rsqrt", 9>, + I32EnumAttrCase<"square", 10>, + I32EnumAttrCase<"tanh", 11>, + I32EnumAttrCase<"erf", 12>, + + // Binary + + I32EnumAttrCase<"add", 13>, + I32EnumAttrCase<"sub", 14>, + I32EnumAttrCase<"mul", 15>, + I32EnumAttrCase<"div", 16>, + I32EnumAttrCase<"div_unsigned", 17>, + I32EnumAttrCase<"max_signed", 18>, + I32EnumAttrCase<"min_signed", 19>, + I32EnumAttrCase<"max_unsigned", 20>, + I32EnumAttrCase<"min_unsigned", 21>, + I32EnumAttrCase<"powf", 22>, + + // Ternary + I32EnumAttrCase<"select", 23> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::linalg"; +} + // Define the function attribute enums matching the OpDSL functions. def UnaryFn : I32EnumAttr<"UnaryFn", "", [ I32EnumAttrCase<"exp", 0>, diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index fff4048ee125e..6d6ff7a5c7872 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -551,6 +551,136 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// Op definition for ElemwiseOp - with user-defined maps, computation type etc. +//===----------------------------------------------------------------------===// + +def ElemwiseOp : LinalgStructuredBase_Op<"elemwise", [ + AttrSizedOperandSegments]> { + let summary = [{ Performs element-wise operation }]; + let description = [{ + Linalg op form which performs element-wise computation. The attribute + `func_type` describes the operation type (e.g. add, exp). The func_type + can be any valid unary, binary, or ternary operation. + + Affine-maps for operands and result may be provided by the user. When + a user-defined indexing_map is not provided, identity map is inferred + for all operands. The default indexing maps are N identity-maps. ‘N’ + depends on the arity of the elementwise op. The number of dims is + inferred from rank of the output type. In the case of default indexing + map, the input and output shapes must all match. Affine-map for operands + and result must be only projected permutations with no zero constants. + + For element-wise iterator-type is always inferred as all ‘parallel’. + Iterator-type is needed for constructing this underlying structured op. + The number of dims of the iterator-type is inferred from the rank of + the result type. + + Example: + Defining a unary linalg.elemwise with default indexing-map: + + ```mlir + %exp = linalg.elemwise + func_type=#linalg.elemwise_fn + ins(%x : tensor<4x16x8xf32>) + outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32> + ``` + + Defining a binary linalg.elemwise with user-defined indexing-map: + + ```mlir + %add = linalg.elemwise + func_type=#linalg.elemwise_fn + indexing_maps = [#transpose, #broadcast, #identity] + ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>) + outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> + ``` + }]; + + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs, + ElemwiseFnAttr:$func_type, + DefaultValuedOptionalAttr:$indexing_maps + ); + + let results = (outs Variadic:$result_tensors); + let regions = (region AnyRegion:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder< + (ins "ValueRange":$inputs, "ValueRange":$outputs, + CArg<"ArrayRef", "{}">:$attributes), + [{ + buildElemwiseOp($_builder, $_state, std::nullopt, inputs, outputs, + attributes, ElemwiseOp::getRegionBuilder()); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let hasVerifier = 1; + + let extraClassDeclaration = structuredOpsBaseDecls # [{ + + /// Get the nary category enum, e.g. `ElemwiseNAryCategory::Unary`, + /// corresponding to the given fn, e.g. `ElemwiseFn::exp` + static ElemwiseNAryCategory getNAryCategory(ElemwiseFn fn); + + /// Elementwise is always `dynamic indexing maps` i.e. `user specified` + /// or `default`. Default is identity-maps. + static bool hasDynamicIndexingMaps() { return true; } + + /// Implements the block region builder for the eemwiseOp. This is called + /// by the 'fillStructuredOpRegion'. + static void regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ArrayRef attrs); + + static std::function)> + getRegionBuilder() { + return regionBuilder; + } + + /// Returns elementwise op kind e.g. `add` inferred from func_type attr. + ElemwiseFn getElemwiseFnVal() { + return getFuncType(); + } + + /// Infer dimensionality of the `iteration space` from the result type. + /// Useful when others means are not possible e.g. in case of absence of + /// user-provided indexing map. + unsigned getResultRank(); + + /// Elementwise op does not have to explicitly specify iterator type + /// as it is always 'parallel'. The number of 'parallel' loops is + /// inferred from other means (e.g. result tensor type). + SmallVector getIteratorTypesArray(); + + /// The default indexing maps are N identity-maps. 'N' depends on the + /// arity of the elementwise op. The default case is when all input + /// output tensors are same rank and no transpose/broadcast is needed. + static SmallVector + getDefaultIndexingMaps(unsigned N, unsigned numDims, + MLIRContext *context); + + /// Returns true if the user defined indexing maps are not equal to + /// the default (identity) map. + bool hasUserDefinedMaps(); + + /// destination passing style interface method. + ::mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputsMutable(); + } + + // Generic methods. + std::string getLibraryCallName() { + return generateLibraryCallName(getOperation()); + } + }]; +} + //===----------------------------------------------------------------------===// // Op definition for MatmulOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index c13b663dbf05b..46c5dece38172 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -203,6 +203,15 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state, attributes, regionBuilder); } +static void buildElemwiseOp(OpBuilder &b, OperationState &state, + std::optional resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef attributes, + RegionBuilderFn regionBuilder) { + return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, + attributes, regionBuilder); +} + /// Common parsing used for both named structured ops created by ods-gen and by /// manually defined C++ ops. Does not handle regions. static ParseResult @@ -3566,6 +3575,7 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) { return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(), MatmulOp::getRegionBuilder()); } + void MatmulOp::print(OpAsmPrinter &p) { SmallVector elidedAttrs = { "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"}; @@ -3611,5 +3621,283 @@ Speculation::Speculatability MatmulOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast(getOperation())); } +//===----------------------------------------------------------------------===// +// ElemwiseOp - with support for affine map, func_type and comp_type +//===----------------------------------------------------------------------===// +// +namespace { +struct NAryCategoryAndFn { + // The enum category class {Unary, Binary, Ternary, ..} + ElemwiseNAryCategory category; + + union NAryFn { + UnaryFn unaryFn; + BinaryFn binaryFn; + TernaryFn ternaryFn; + } fn; + + ::llvm::StringRef stringifyCategory() { + switch (category) { + case ElemwiseNAryCategory::Unary: + return "unary"; + case ElemwiseNAryCategory::Binary: + return "binary"; + case ElemwiseNAryCategory::Ternary: + return "ternary"; + } + llvm_unreachable("unknown-category"); + } + + ::llvm::StringRef stringifyFn() { + switch (category) { + case ElemwiseNAryCategory::Unary: + return stringifyUnaryFn(fn.unaryFn); + case ElemwiseNAryCategory::Binary: + return stringifyBinaryFn(fn.binaryFn); + case ElemwiseNAryCategory::Ternary: + return stringifyTernaryFn(fn.ternaryFn); + } + llvm_unreachable("unknown-fn"); + } +}; + +unsigned getArityFromCategory(ElemwiseNAryCategory category) { + switch (category) { + case ElemwiseNAryCategory::Unary: + return 1; + case ElemwiseNAryCategory::Binary: + return 2; + case ElemwiseNAryCategory::Ternary: + return 3; + } + llvm_unreachable("unhandled category"); +} +} // namespace + +static NAryCategoryAndFn getNAryCategoryAndFn(ElemwiseFn fn) { + constexpr int lastUnary = static_cast(ElemwiseFn::erf); + constexpr int lastBinary = static_cast(ElemwiseFn::powf); + constexpr int lastTernary = static_cast(ElemwiseFn::select); + + int val = static_cast(fn); + NAryCategoryAndFn result; + if (val <= lastUnary) { + result.category = ElemwiseNAryCategory::Unary; + result.fn.unaryFn = static_cast(val); + return result; + } + if (val <= lastBinary) { + result.category = ElemwiseNAryCategory::Binary; + result.fn.binaryFn = static_cast(val - lastUnary - 1); + return result; + } + if (val > lastTernary) { + llvm_unreachable("unhandled ElemwiseFn"); + } + result.category = ElemwiseNAryCategory::Ternary; + result.fn.ternaryFn = static_cast(val - lastBinary - 1); + return result; +} + +unsigned ElemwiseOp::getResultRank() { + auto output = getDpsInitOperand(0)->get(); + auto shapedType = llvm::cast(output.getType()); + return shapedType.getRank(); +} + +SmallVector ElemwiseOp::getIteratorTypesArray() { + auto rank = getResultRank(); + return SmallVector(rank, utils::IteratorType::parallel); +} + +SmallVector +ElemwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims, + MLIRContext *context) { + auto map = AffineMap::getMultiDimIdentityMap(numDims, context); + return SmallVector(numMaps, map); +} + +bool ElemwiseOp::hasUserDefinedMaps() { + auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category; + auto arity = getArityFromCategory(category); + + auto numDims = getResultRank(); + SmallVector defaultMaps = + getDefaultIndexingMaps(arity + 1, numDims, this->getContext()); + SmallVector explicitMaps = getIndexingMapsArray(); + return defaultMaps != explicitMaps; +} + +ParseResult ElemwiseOp::parse(OpAsmParser &parser, OperationState &result) { + // Expect e.g. `func_type = #linalg.elemwise_fn` + Attribute attr; + mlir::linalg::ElemwiseFn elemwiseFnVal; + if (parser.parseKeyword("func_type")) + return failure(); + if (parser.parseEqual()) + return failure(); + + if (succeeded(parser.parseAttribute(attr))) { + auto elemwiseFnAttr = dyn_cast(attr); + if (!elemwiseFnAttr) + return parser.emitError(parser.getCurrentLocation(), + "expected ElemwiseFn attribute"); + elemwiseFnVal = elemwiseFnAttr.getValue(); + } else { + return parser.emitError(parser.getCurrentLocation(), + "expected 'func_type' attribute"); + } + result.addAttribute("func_type", + ElemwiseFnAttr::get(parser.getContext(), elemwiseFnVal)); + + // Parse optional `indexing_maps` + SmallVector indexingMapsAttr; + Attribute mapAttr; + if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) { + if (parser.parseEqual()) + return failure(); + + if (parser.parseLSquare()) + return failure(); + do { + if (parser.parseAttribute(mapAttr)) + return failure(); + + if (!isa(mapAttr)) + return parser.emitError(parser.getCurrentLocation(), + "expected affine map attribute"); + + indexingMapsAttr.push_back(mapAttr); + + if (parser.parseOptionalComma()) + break; + } while (true); + + if (parser.parseRSquare()) + return failure(); + } + + // At this stage of parsing the only way to infer number of region + // args is through op kind, as input output tensors are not parsed yet. + auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal); + auto arity = getArityFromCategory(arityAndCategory.category); + int numRegionArgs = arity + 1 /*output*/; + + if (parseNamedStructuredOp(parser, result, numRegionArgs, + ElemwiseOp::getRegionBuilder())) { + return parser.emitError(parser.getCurrentLocation(), + "unable to parse elemwise op"); + } + + // Initialize indexingMaps, if not supplied explicitly. + if (indexingMapsAttr.empty()) { + // We need to infer the `number of indexing maps` needed from the result + // type which is already parsed by now. + auto resultType = result.operands[result.operands.size() - 1].getType(); + auto shapedType = llvm::dyn_cast(resultType); + if (!shapedType) + return parser.emitError(parser.getCurrentLocation(), + "return type needs to be shaped type"); + auto numDims = shapedType.getRank(); + indexingMapsAttr = llvm::map_to_vector( + ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims, + parser.getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + } + result.addAttribute("indexing_maps", + parser.getBuilder().getArrayAttr(indexingMapsAttr)); + return success(); +} + +void ElemwiseOp::print(OpAsmPrinter &p) { + p << " func_type="; + p.printAttribute(getFuncTypeAttr()); + + SmallVector elidedAttrs = {"operandSegmentSizes", "func_type", + "indexing_maps"}; + + auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category; + auto arity = getArityFromCategory(category); + + auto numDims = getResultRank(); + SmallVector indexingMaps = llvm::map_to_vector( + ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()), + [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + if (!llvm::equal(getIndexingMaps(), indexingMaps)) { + p << " indexing_maps = ["; + llvm::interleaveComma(getIndexingMaps(), p, + [&](Attribute attr) { p.printAttribute(attr); }); + p << "]"; + } + + printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(), + elidedAttrs); +} + +LogicalResult ElemwiseOp::verify() { + // All necessary checks are done either by + // - EnumAttr (e.g. unknown func_type) + // - verifyStructuredOpInterface (incorrect map, sizes). + return success(); +} + +/// Implements the block region builder for the ElemwiseOp. This is called by +/// 'fillStructuredOpRegion'. +void ElemwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, + ArrayRef attrs) { + ElemwiseFn elemwiseFn; + for (auto attr : attrs) { + if (attr.getName() == b.getStringAttr("func_type")) { + auto funcTypeAttr = dyn_cast(attr.getValue()); + assert(funcTypeAttr && "func_type attribute incorrectly set"); + elemwiseFn = funcTypeAttr.getValue(); + break; + } + } + + NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(elemwiseFn); + ElemwiseNAryCategory category = categoryAndFn.category; + + unsigned numBlockArgs = getArityFromCategory(categoryAndFn.category) + 1; + assert(block.getNumArguments() == numBlockArgs && + "Elemwise regionBuilder number of block args mismatch"); + + RegionBuilderHelper helper(b, block); + SmallVector yields; + Value result; + + if (category == ElemwiseNAryCategory::Unary) { + result = + helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0)); + + } else if (category == ElemwiseNAryCategory::Binary) { + result = helper.buildBinaryFn(categoryAndFn.fn.binaryFn, + block.getArgument(0), block.getArgument(1)); + } else if (category == ElemwiseNAryCategory::Ternary) { + result = + helper.buildTernaryFn(categoryAndFn.fn.ternaryFn, block.getArgument(0), + block.getArgument(1), block.getArgument(2)); + } else + assert(false && "found unhandled category in elemwise print"); + + yields.push_back(result); + helper.yieldOutputs(yields); +} + +LogicalResult ElemwiseOp::fold(FoldAdaptor, SmallVectorImpl &) { + return memref::foldMemRefCast(*this); +} +void ElemwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (hasPureTensorSemantics()) + return; + getGenericEffectsImpl(effects, cast(getOperation())); +} + +Speculation::Speculatability ElemwiseOp::getSpeculatability() { + return getGenericSpeculatabilityImpl(cast(getOperation())); +} + } // namespace linalg } // namespace mlir diff --git a/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir new file mode 100644 index 0000000000000..9c5ac619515c0 --- /dev/null +++ b/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir @@ -0,0 +1,170 @@ +// RUN: mlir-opt %s -linalg-generalize-named-ops -split-input-file | FileCheck %s + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// +// CHECK: @unary_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[A]] +// CHECK-SAME: outs(%[[B]] +// +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32) +// CHECK: %[[EXP:.+]] = math.exp %[[A_ARG]] : f32 +// CHECK: linalg.yield %[[EXP]] : f32 +// +func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { + %r = linalg.elemwise + func_type=#linalg.elemwise_fn + ins(%A : tensor<8x16x32xf32>) + outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> + return %r : tensor<8x16x32xf32> +} + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// +// CHECK: @unary_transpose_broadcast_tanh(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>) +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[A]] +// CHECK-SAME: outs(%[[B]] +// +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32) +// CHECK: %[[TANH:.+]] = math.tanh %[[A_ARG]] : f32 +// CHECK: linalg.yield %[[TANH]] : f32 +// +func.func @unary_transpose_broadcast_tanh(%A : tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { + %r = linalg.elemwise + func_type=#linalg.elemwise_fn + indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>] + ins(%A : tensor<32x16xf32>) + outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> + return %r : tensor<8x16x32xf32> +} + +// ----- + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// +// CHECK: @binary_div_on_memrefs(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>) +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[A]], %[[B]] +// CHECK-SAME: outs(%[[C]] +// +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32) +// CHECK: %[[DIV:.+]] = arith.divf %[[A_ARG]], %[[B_ARG]] : f32 +// CHECK: linalg.yield %[[DIV]] : f32 +// +func.func @binary_div_on_memrefs(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) { + linalg.elemwise + func_type=#linalg.elemwise_fn
+ ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) + outs(%C: memref<16x8xf32>) + return +} + +// ----- + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// +// CHECK: @binary_mul_on_tensors(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[A]], %[[B]] +// CHECK-SAME: outs(%[[C]] +// +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32) +// CHECK: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32 +// CHECK: linalg.yield %[[MUL]] : f32 +// +func.func @binary_mul_on_tensors(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> { + %r = linalg.elemwise + func_type=#linalg.elemwise_fn + ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>) + outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32> + return %r : tensor<16x8xf32> +} + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// +// CHECK: @binary_transpose_a(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[IDENTITY]], #[[IDENTITY]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[A]], %[[B]] +// CHECK-SAME: outs(%[[C]] +// +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32) +// CHECK: %[[SUB:.+]] = arith.subf %[[A_ARG]], %[[B_ARG]] : f32 +// CHECK: linalg.yield %[[SUB]] : f32 +// +func.func @binary_transpose_a(%A : tensor<8x16xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> { + %r = linalg.elemwise + func_type=#linalg.elemwise_fn + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>] + ins(%A, %B: tensor<8x16xf32>, tensor<16x8xf32>) + outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32> + return %r : tensor<16x8xf32> +} + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[BROADCAST:.+]] = affine_map<(d0, d1) -> (d0)> +// +// CHECK: @binary_transpose_a_broadcast_b(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16x8xf32>) +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[BROADCAST]], #[[IDENTITY]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[A]], %[[B]] +// CHECK-SAME: outs(%[[C]] +// +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32) +// CHECK: %[[ADD:.+]] = arith.addf %[[A_ARG]], %[[B_ARG]] : f32 +// CHECK: linalg.yield %[[ADD]] : f32 +// +func.func @binary_transpose_a_broadcast_b(%A : tensor<8x16xf32>, %B: tensor<16xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> { + %r = linalg.elemwise + func_type=#linalg.elemwise_fn + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>] + ins(%A, %B: tensor<8x16xf32>, tensor<16xf32>) + outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32> + return %r : tensor<16x8xf32> +} + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// +// CHECK: @ternary(%[[A:.+]]: tensor<32x16xi1>, %[[B:.+]]: tensor<8x16x32xf32>, %[[C:.+]]: tensor<8x16x32xf32>, %[[D:.+]]: tensor<8x16x32xf32>) +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]], #[[IDENTITY]], #[[IDENTITY]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// +// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] +// CHECK-SAME: outs(%[[D]] +// +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i1, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32, %[[D_ARG:.+]]: f32) +// CHECK: %[[SELECTED:.+]] = arith.select %[[A_ARG]], %[[B_ARG]], %[[C_ARG]] : f32 +// CHECK: linalg.yield %[[SELECTED]] : f32 +// +func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8x16x32xf32>, %D : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { + %r = linalg.elemwise + func_type=#linalg.elemwise_fn