Skip to content

Commit

Permalink
[mlir][tensor][linalg] Move Pack/Unpack Ops to Linalg (3/4)
Browse files Browse the repository at this point in the history
This is merely moving code around, no new functionality is added.

PATCH 3: Update/move/replace all tests for `tensor.{pack|unpack}` with
identical tests for `linalg.{pack|unpack}`. Update the testing
infrastructure accordingly and copy all the required transformations.

To help reviewing, below is an overview of non-obvious code moves:

1. Tests from:
  * "mlir/test/Dialect/Tensor/tiling.mlir"
are moved to:
  * "mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir"

2. Tests from:
  * "mlir/test/Dialect/Tensor/fold-empty-op.mlir"
are moved to:
  * "mlir/test/Dialect/Linalg/fold-empty-op.mlir"

CONTEXT:
This change was discussed in the following RFC:
* https://discourse.llvm.org/t/rfc-move-tensor-pack-and-tensor-unpack-into-linalg
  • Loading branch information
banach-space committed Feb 11, 2025
1 parent 0d6d732 commit 2ac6868
Show file tree
Hide file tree
Showing 58 changed files with 2,956 additions and 1,968 deletions.
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
];
}

def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
let description = [{
A Linalg relayout-op is either linalg.pack or linalg.unpack.

While we could extend this interface with methods from Linalg_RelayoutOp,
this is currently not needed and left as a TODO.
}];
let cppNamespace = "::mlir::linalg";
}

def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
let description = [{
A fill operation is defined in general terms:
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/IR/OpAsmInterface.td"

//===----------------------------------------------------------------------===//
Expand All @@ -31,7 +32,7 @@ include "mlir/IR/OpAsmInterface.td"
class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
ConditionallySpeculatable, NoMemoryEffect,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
TypesMatchWith<"result type matches type of dest",
Expand Down
86 changes: 53 additions & 33 deletions mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def ApplyDecomposeTensorPackUnpackPatternsOp
: Op<Transform_Dialect, "apply_patterns.linalg.decompose_pack_unpack",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to decompose tensor.pack and tensor.unpack into e.g.
Collect patterns to decompose linalg.pack and linalg.unpack into e.g.
tensor::PadOp, linalg::TransposeOp Ops. Requires all outer dims to be unit.
}];

Expand Down Expand Up @@ -126,6 +126,28 @@ def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyFoldIntoPackAndUnpackPatternsOp : Op<Transform_Dialect,
"apply_patterns.tensor.fold_into_pack_and_unpack",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that operations like tensor.pad and tensor.extract_slice should
be folded into linalg.pack and linalg.unpack operations, respectively.
}];

let assemblyFormat = "attr-dict";
}

def ApplyFoldPackUnpackIntoEmptyPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.fold_pack_unpack_into_empty",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
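Indicates that linalg.pack and linalg.unpack operations should be folded
into tensor.empty.

// TODO: Document the `fold_single_use_only` attribute.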
}];

let arguments = (ins DefaultValuedAttr<BoolAttr, "false">:$fold_single_use_only);
let assemblyFormat = "attr-dict";
}

//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -547,19 +569,18 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
TransformOpInterface,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Rewrite a tensor.pack into tensor.pad + tensor.expand_shape + linalg.transpose.
Rewrite a linalg.pack into tensor.pad + tensor.expand_shape + linalg.transpose.

#### Return modes

This operation ignores non-pack ops and drops them in the return.
This operation produces a silenceable failure if the rewrite fails for any
reason.
If all the operations referred to by the `target` are rewritten, the
transform succeeds.
Return handles to the newly produced pad, expand_shape and transpose ops.
This operation ignores non-pack ops and drops them in the return. This
operation produces a silenceable failure if the rewrite fails for any
reason. If all the operations referred to by the `target` are rewritten,
the transform succeeds. Return handles to the newly produced pad,
expand_shape and transpose ops.
}];

let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
let arguments = (ins Transform_ConcreteOpType<"linalg.pack">:$target,
DefaultValuedAttr<BoolAttr, "true">:$lowerPadLikeWithInsertSlice);
let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
Expand All @@ -571,7 +592,7 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::tensor::PackOp target,
::mlir::linalg::PackOp target,
::mlir::transform::ApplyToEachResultList &transformResults,
::mlir::transform::TransformState &state);
}];
Expand All @@ -587,20 +608,19 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
TransformOpInterface,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Lower a tensor.unpack into empty + linalg.transpose + tensor.collapse_shape +
Lower a linalg.unpack into empty + linalg.transpose + tensor.collapse_shape +
tensor.extract_slice.

#### Return modes

This operation ignores non-unpack ops and drops them in the return.
This operation produces a silenceable failure if the rewrite fails for any
reason.
If all the operations referred to by the `target` are rewritten, the
transform succeeds.
Return handles to the newly produced empty, transpose, collapse_shape and extract_slice ops.
This operation ignores non-unpack ops and drops them in the return. This
operation produces a silenceable failure if the rewrite fails for any
reason. If all the operations referred to by the `target` are rewritten,
the transform succeeds. Return handles to the newly produced empty,
transpose, collapse_shape and extract_slice ops.
}];

let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
let arguments = (ins Transform_ConcreteOpType<"linalg.unpack">:$target,
DefaultValuedAttr<BoolAttr, "true">:$lowerUnpadLikeWithExtractSlice);
let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
Expand All @@ -613,7 +633,7 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::tensor::UnPackOp target,
::mlir::linalg::UnPackOp target,
::mlir::transform::ApplyToEachResultList &transformResults,
::mlir::transform::TransformState &state);
}];
Expand Down Expand Up @@ -791,7 +811,7 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
Specifying a packed size of 0 for an iterator removes it from consideration
for packing.

`tensor.pack` (resp. `tensor.unpack`) operations are inserted for the operands
`linalg.pack` (resp. `linalg.unpack`) operations are inserted for the operands
(resp. results) that need to be packed (resp. unpacked) according to the
`packed_sizes` specification.

Expand Down Expand Up @@ -980,7 +1000,7 @@ def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and
Apply a transposition to a single `linalg.pack` (resp. `linalg.unpack`) and
update the `linalg.generic` op that consumes (resp. produces) the operation.

This transform allows composing a simple `structured.pack` with additional
Expand All @@ -989,19 +1009,19 @@ def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [

The transpose spec must specify at least one of `outer_perm` or `inner_perm`
attributes, which will act upon the `outer_dims_perm` or `inner_dims_pos` of
the specified `tensor.pack` or `tensor.unpack` op.
the specified `linalg.pack` or `linalg.unpack` op.

If the `target` of this op is a `tensor.pack` then a new `tensor.empty` will
be created along with transposed versions of the `tensor.pack` and the
If the `target` of this op is a `linalg.pack` then a new `tensor.empty` will
be created along with transposed versions of the `linalg.pack` and the
consuming `linalg.generic`, which is expected to be the sole consumer.

If the `target` of this op is a `tensor.unpack` then the whole pack / compute
/ unpack chain will be transposed and transposed clones of `tensor.pack`,
the consuming `linalg.generic` and the tail `tensor.pack` will be created.
If the `target` of this op is a `linalg.unpack` then the whole pack / compute
/ unpack chain will be transposed and transposed clones of `linalg.pack`,
the consuming `linalg.generic` and the tail `linalg.pack` will be created.

#### Return modes

This operation targets a single `tensor.pack` / `tensor.unpack` op and a
This operation targets a single `linalg.pack` / `linalg.unpack` op and a
single matching `linalg.generic` that consumes / produces the op. Otherwise,
it produces a silenceableFailure.

Expand All @@ -1011,9 +1031,9 @@ def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
reason.

This operation returns 3 handles, one to the transformed LinalgOp, one to
the transformed `tensor.pack` and one to the transformed `tensor.unpack`.
The last handle for `tensor.unpack` is empty if `target_pack_or_unpack_op`
was not itself a `tensor.unpack`.
the transformed `linalg.pack` and one to the transformed `linalg.unpack`.
The last handle for `linalg.unpack` is empty if `target_pack_or_un_pack_op`
was not itself a `linalg.unpack`.
}];

let arguments = (ins TransformHandleTypeInterface:$target_pack_or_un_pack_op,
Expand Down Expand Up @@ -1143,7 +1163,7 @@ def HoistPadBuildPackingLoopNestOp :
creates the packing loop nest required by the hoist_pad operation and makes
that functionality available independently.

TODO: In the future, we should consider rewriting as a tensor.pack after
TODO: In the future, we should consider rewriting as a linalg.pack after
hoisting since this abstraction is now available.

#### Return modes
Expand Down Expand Up @@ -1182,7 +1202,7 @@ def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
Hoist the tensor.pad target operation by at most the given number of loops.
Optionally apply the transpose attribute to the inner dimensions.

TODO: In the future, we should consider rewriting as a tensor.pack after
TODO: In the future, we should consider rewriting as a linalg.pack after
hoisting since this abstraction is now available.
TODO: Maybe also return the linalg.generic transpose created at some point.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ class DialectRegistry;

namespace linalg {
void registerTilingInterfaceExternalModels(DialectRegistry &registry);

/// Similar to the above registration, but it is only for `linalg.pack` and
/// `linalg.unpack` ops.
void registerTilingInterfaceExternalModelsForPackUnPackOps(
DialectRegistry &registry);
} // namespace linalg
} // namespace mlir

Expand Down
59 changes: 36 additions & 23 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist,
/// packed tensor. A `transposeVector` can change the storage order of the
/// padded tensor but does not change the order of the pack or compute loops.
///
/// TODO: In the future, we should consider rewriting as a tensor.pack after
/// TODO: In the future, we should consider rewriting as a linalg.pack after
/// hoisting since this abstraction is now available.
///
/// Example in pseudo-mlir:
Expand Down Expand Up @@ -1121,7 +1121,7 @@ struct LowerPackResult {

/// Rewrite pack as pad + reshape + transpose.
FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp,
linalg::PackOp packOp,
bool lowerPadLikeWithInsertSlice = true);

struct LowerUnPackOpResult {
Expand All @@ -1133,14 +1133,14 @@ struct LowerUnPackOpResult {

/// Rewrite unpack as empty + transpose + reshape + extract_slice.
FailureOr<LowerUnPackOpResult>
lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
bool lowerUnpadLikeWithExtractSlice = true);

/// Struct to hold the result of a `pack` call.
struct PackResult {
SmallVector<tensor::PackOp> packOps;
SmallVector<linalg::PackOp> packOps;
linalg::LinalgOp packedLinalgOp;
SmallVector<tensor::UnPackOp> unPackOps;
SmallVector<linalg::UnPackOp> unPackOps;
};
/// Implement packing of a single LinalgOp by `packedSizes`.
/// There must be one packedSizes entry per `linalgOp` iterator.
Expand All @@ -1150,9 +1150,9 @@ FailureOr<PackResult> pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp,

/// Struct to hold the result of a `packTranspose` call.
struct PackTransposeResult {
tensor::PackOp transposedPackOp;
linalg::PackOp transposedPackOp;
linalg::LinalgOp transposedLinalgOp;
tensor::UnPackOp transposedUnPackOp;
linalg::UnPackOp transposedUnPackOp;
};
/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the
/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements.
Expand All @@ -1163,8 +1163,8 @@ struct PackTransposeResult {
/// 3. `outerPerm` (resp. `innerPerm`) must be valid permutations of
/// `packOp.getOuterDimsPerm` (resp. `packOp.getInnerDimsPos`) or empty.
FailureOr<PackTransposeResult>
packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);

/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
Expand Down Expand Up @@ -1525,15 +1525,15 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
const SmallVector<Value> &dynSizes) const;
};

/// Rewrites a tensor::PackOp into a sequence of:
/// Rewrites a linalg::PackOp into a sequence of:
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
/// tensor::InsertSliceOp ops.
///
/// Requires that all the outer dims of the input tensor::PackOp are 1.
/// Requires that all the outer dims of the input linalg::PackOp are 1.
///
/// Before:
/// ```
/// %packed = tensor.pack %input
/// %packed = linalg.pack %input
/// padding_value(%pad : f32)
/// inner_dims_pos = [1, 0]
/// inner_tiles = [2, %high]
Expand All @@ -1559,20 +1559,20 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
/// : tensor<2x?xf32> into tensor<1x1x2x?xf32>
/// ```
struct DecomposeOuterUnitDimsPackOpPattern
: public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::PackOp packOp,
: public OpRewritePattern<linalg::PackOp> {
using OpRewritePattern<linalg::PackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override;
};

/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced
/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
///
/// Requires that all the outer dims of the input tensor::PackOp are 1.
/// Requires that all the outer dims of the input linalg::UnPackOp are 1.
///
/// Before:
/// ```
/// %packed = tensor.unpack %input
/// %packed = linalg.unpack %input
/// inner_dims_pos = [1, 0]
/// inner_tiles = [2, 8]
/// into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
Expand All @@ -1593,9 +1593,9 @@ struct DecomposeOuterUnitDimsPackOpPattern
/// : tensor<8x2xf32> to tensor<5x1xf32>
/// ```
struct DecomposeOuterUnitDimsUnPackOpPattern
: public OpRewritePattern<tensor::UnPackOp> {
using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp,
: public OpRewritePattern<linalg::UnPackOp> {
using OpRewritePattern<linalg::UnPackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp,
PatternRewriter &rewriter) const override;
};

Expand Down Expand Up @@ -1717,7 +1717,7 @@ void populateLinalgGenericOpsSpecializationPatterns(
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
/// Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all
/// outer dims to be unit.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);
Expand Down Expand Up @@ -1779,7 +1779,7 @@ void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);

/// Function type which is used to control propagation of tensor.pack/unpack
/// Function type which is used to control propagation of linalg.pack/unpack
/// ops.
using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;

Expand Down Expand Up @@ -1888,6 +1888,19 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `linalg.pack` and `linalg.unpack` operations
/// respectively.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold operations like `linalg.pack`
/// and `linalg.unpack` into `tensor.empty`.
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that simplify `linalg.pack` and
/// `linalg.unpack` operations.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);

} // namespace linalg
} // namespace mlir

Expand Down
Loading

0 comments on commit 2ac6868

Please sign in to comment.