Hoisting vector.transfer operations for bf16 type #1012
Conversation
FYI,
Yes, because of
Just to document this somewhere: brgemm-with-init-pipeline.mlir
module {
func.func @brgemm_tpp(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
%0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
return %0 : tensor<256x256xf32>
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {deduplicate, op_name = "builtin.module"} : (!transform.any_op) -> !transform.any_op
transform.print %1 : !transform.any_op
%2 = transform.apply_registered_pass "fold-add-into-dest" to %1 : (!transform.any_op) -> !transform.any_op
%3 = transform.apply_registered_pass "fold-into-eltwise" to %2 : (!transform.any_op) -> !transform.any_op
%4 = transform.structured.match ops{["func.func"]} in %3 : (!transform.any_op) -> !transform.any_op
%5 = transform.apply_registered_pass "convert-linalg-to-inplace" to %4 : (!transform.any_op) -> !transform.any_op
%6 = transform.apply_registered_pass "rewrite-batch-matmul-to-matmul" to %5 : (!transform.any_op) -> !transform.any_op
%7 = transform.structured.match ops{["func.func"]} in %3 : (!transform.any_op) -> !transform.any_op
%8 = transform.apply_registered_pass "conv-init-simplify" to %7 : (!transform.any_op) -> !transform.any_op
%9 = transform.apply_registered_pass "canonicalize" to %3 : (!transform.any_op) -> !transform.any_op
transform.apply_cse to %9 : !transform.any_op
%10 = transform.structured.match ops{["func.func"]} in %9 : (!transform.any_op) -> !transform.any_op
%11 = transform.apply_registered_pass "pack-conv2DNchwFchw" to %10 : (!transform.any_op) -> !transform.any_op
%12 = transform.apply_registered_pass "pack-conv2DNhwcHwcf" to %11 : (!transform.any_op) -> !transform.any_op
%13 = transform.apply_registered_pass "rewrite-conv-to-matmul-or-brgemm" to %12 : (!transform.any_op) -> !transform.any_op
%14 = transform.apply_registered_pass "pack-matmul" to %13 : (!transform.any_op) -> !transform.any_op
%15 = transform.apply_registered_pass "pack-vnni" to %14 : (!transform.any_op) -> !transform.any_op
%16 = transform.structured.match ops{["func.func"]} in %9 : (!transform.any_op) -> !transform.any_op
%17 = transform.apply_registered_pass "propagate-pack-and-unpack" to %16 : (!transform.any_op) -> !transform.any_op
%18 = transform.apply_registered_pass "constant-fold-pack" to %9 : (!transform.any_op) -> !transform.any_op
%19 = transform.structured.match ops{["func.func"]} in %18 : (!transform.any_op) -> !transform.any_op
%20 = transform.apply_registered_pass "simplify-pack" to %19 : (!transform.any_op) -> !transform.any_op
%21 = transform.apply_registered_pass "linalg-generalize-named-ops" to %20 : (!transform.any_op) -> !transform.any_op
%22 = transform.apply_registered_pass "canonicalize" to %18 : (!transform.any_op) -> !transform.any_op
transform.apply_cse to %22 : !transform.any_op
%23 = transform.structured.match ops{["func.func"]} in %22 : (!transform.any_op) -> !transform.any_op
%24 = transform.apply_registered_pass "linalg-convert-compare-select-to-maximumf-pass" to %23 : (!transform.any_op) -> !transform.any_op
%25 = transform.apply_registered_pass "tile-consumer-and-fuse-producers" to %24 : (!transform.any_op) -> !transform.any_op
%26 = transform.apply_registered_pass "simplify-pack" to %25 : (!transform.any_op) -> !transform.any_op
%27 = transform.apply_registered_pass "canonicalize" to %22 : (!transform.any_op) -> !transform.any_op
transform.apply_cse to %27 : !transform.any_op
%28 = transform.structured.match ops{["func.func"]} in %27 : (!transform.any_op) -> !transform.any_op
%29 = transform.apply_registered_pass "lower-packs-unpacks" to %28 : (!transform.any_op) -> !transform.any_op
%30 = transform.apply_registered_pass "canonicalize" to %27 : (!transform.any_op) -> !transform.any_op
transform.apply_cse to %30 : !transform.any_op
%31 = transform.structured.match ops{["func.func"]} in %30 : (!transform.any_op) -> !transform.any_op
%32 = transform.apply_registered_pass "decompose-aggregated-ops" to %31 : (!transform.any_op) -> !transform.any_op
transform.print %30 {name = "before-bufferize"} : !transform.any_op
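// Tile the brgemm along m (by 32) and k (by 64); tile_sizes follow the (batch, m, n, k) iteration order.
// Then hoist the loop-invariant accumulator extract/insert_slice pair out of the innermost (k) loop.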
%33 = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %30 : (!transform.any_op) -> !transform.any_op
%tiled_linalg_op, %loops:2 = transform.structured.tile_using_for %33 tile_sizes [0, 32, 0, 64] interchange = [0, 1, 2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%34 = transform.structured.match ops{["func.func"]} in %30 : (!transform.any_op) -> !transform.any_op
%35 = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %34 : (!transform.any_op) -> !transform.any_op
%36 = transform.get_parent_op %35 {op_name = "scf.for"} : (!transform.any_op) -> !transform.any_op
transform.loop.hoist_loop_invariant_subsets %36 : !transform.any_op
transform.yield
}
}
gives
module {
func.func @brgemm_tpp(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
%c0 = arith.constant 0 : index
%c256 = arith.constant 256 : index
%c512 = arith.constant 512 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%0 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %arg2) -> (tensor<256x256xf32>) {
%extracted_slice = tensor.extract_slice %arg4[%arg3, 0] [32, 256] [1, 1] : tensor<256x256xf32> to tensor<32x256xf32>
%1 = scf.for %arg5 = %c0 to %c512 step %c64 iter_args(%arg6 = %extracted_slice) -> (tensor<32x256xf32>) {
%extracted_slice_0 = tensor.extract_slice %arg0[0, %arg3, %arg5] [128, 32, 64] [1, 1, 1] : tensor<128x256x512xf32> to tensor<128x32x64xf32>
%extracted_slice_1 = tensor.extract_slice %arg1[0, %arg5, 0] [128, 64, 256] [1, 1, 1] : tensor<128x512x256xf32> to tensor<128x64x256xf32>
%2 = linalg.batch_reduce_matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<128x32x64xf32>, tensor<128x64x256xf32>) outs(%arg6 : tensor<32x256xf32>) -> tensor<32x256xf32>
scf.yield %2 : tensor<32x256xf32>
}
%inserted_slice = tensor.insert_slice %1 into %arg4[%arg3, 0] [32, 256] [1, 1] : tensor<32x256xf32> into tensor<256x256xf32>
scf.yield %inserted_slice : tensor<256x256xf32>
}
return %0 : tensor<256x256xf32>
}
}
when run with
That is to say: hoisting works on linalg with upstream tools, though the current blocker is that bufferization doesn't do the right thing on the above IR. So if coming from linalg, we can do tiling and hoisting there. If we are not coming from linalg, we still need the appropriate transforms on vector.
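For the vector path, a minimal sketch (an assumption for illustration, not taken from this thread) of how the upstream transform op for hoisting redundant vector transfers could be applied to a function:
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    // Hoist loop-invariant vector.transfer_read/transfer_write pairs out of their enclosing loops.
    %hoisted = transform.structured.hoist_redundant_vector_transfers %f : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}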
This PR includes changes to extend support for hoisting vector.transfer read/write operations outside the batch and k loops for bf16 type with vnni layout.
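For illustration, a minimal sketch of the general shape such hoisting targets (not taken from this PR; shapes, names, and the f32 element type are invented): the accumulator transfer_read is lifted above the batch and k loops, the matching transfer_write below them, and the value travels through iter_args.
func.func @hoisting_sketch(%A: memref<4x8x32xf32>, %C: memref<32xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c8 = arith.constant 8 : index
  %pad = arith.constant 0.0 : f32
  // Accumulator read hoisted above both loops.
  %init = vector.transfer_read %C[%c0], %pad {in_bounds = [true]} : memref<32xf32>, vector<32xf32>
  %res = scf.for %b = %c0 to %c4 step %c1 iter_args(%acc0 = %init) -> (vector<32xf32>) {
    %res_k = scf.for %k = %c0 to %c8 step %c1 iter_args(%acc1 = %acc0) -> (vector<32xf32>) {
      %a = vector.transfer_read %A[%b, %k, %c0], %pad {in_bounds = [true]} : memref<4x8x32xf32>, vector<32xf32>
      %next = arith.addf %acc1, %a : vector<32xf32>
      scf.yield %next : vector<32xf32>
    }
    scf.yield %res_k : vector<32xf32>
  }
  // Matching accumulator write hoisted below both loops.
  vector.transfer_write %res, %C[%c0] {in_bounds = [true]} : vector<32xf32>, memref<32xf32>
  return
}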