
Hoisting vector.transfer operations for bf16 type #1012

Merged
3 commits merged into libxsmm:main on Feb 17, 2025

Conversation

arun-thmn (Contributor) commented Feb 12, 2025

This PR extends support for hoisting vector.transfer read/write operations outside the batch and k loops for the bf16 type with VNNI layout.
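For illustration, here is a minimal sketch of the kind of rewrite this enables (the shapes, value names, and accumulator element type are hypothetical, not taken from this PR): the accumulator tile is read once before the reduction loop, carried as an iter_arg, and written back once after the loop, instead of being re-read and re-written on every iteration.

// Before: the C tile is re-read and re-written on every iteration of the k loop
// (hypothetical shapes; the bf16 A/B operands would be in VNNI layout).
scf.for %k = %c0 to %c512 step %c32 {
  %acc = vector.transfer_read %csub[%c0, %c0], %pad
    : memref<32x32xf32, strided<[256, 1], offset: ?>>, vector<32x32xf32>
  // ... contraction producing %upd from %acc and the bf16 VNNI operands ...
  vector.transfer_write %upd, %csub[%c0, %c0]
    : vector<32x32xf32>, memref<32x32xf32, strided<[256, 1], offset: ?>>
}

// After hoisting: one read before the loop, the vector carried as an iter_arg,
// and one write after the loop.
%init = vector.transfer_read %csub[%c0, %c0], %pad
  : memref<32x32xf32, strided<[256, 1], offset: ?>>, vector<32x32xf32>
%final = scf.for %k = %c0 to %c512 step %c32 iter_args(%acc = %init) -> (vector<32x32xf32>) {
  // ... contraction producing %upd from %acc and the bf16 VNNI operands ...
  scf.yield %upd : vector<32x32xf32>
}
vector.transfer_write %final, %csub[%c0, %c0]
  : vector<32x32xf32>, memref<32x32xf32, strided<[256, 1], offset: ?>>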

arun-thmn marked this pull request as ready for review on February 12, 2025, 15:37
arun-thmn added the benchmark-full (Benchmark all targets) label on Feb 12, 2025
adam-smnk (Contributor) commented:

FYI, linalg::hoistRedundantVectorTransfers should be able to handle this hoisting if we manage to relax its current aliasing check:

//   2. source operands for transfer.{read|write} do not originate from
//      Ops implementing ViewLikeOpInterface.

arun-thmn (Contributor, Author) commented Feb 17, 2025

> FYI, linalg::hoistRedundantVectorTransfers should be able to handle this hoisting if we manage to relax its current aliasing check:
>
> //   2. source operands for transfer.{read|write} do not originate from
> //      Ops implementing ViewLikeOpInterface.

Yes, because of the subview on the memref, it treats every reference to the subview as a potential alias and therefore avoids hoisting.
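For context, a hypothetical example of the pattern that check currently rejects (names and sizes invented for illustration): the transfer source comes from memref.subview, which implements ViewLikeOpInterface, so the upstream utility conservatively assumes it may alias and leaves the transfers inside the loop.

%sub = memref.subview %buf[%i, 0] [32, 256] [1, 1]
  : memref<256x256xbf16> to memref<32x256xbf16, strided<[256, 1], offset: ?>>
scf.for %k = %c0 to %c512 step %c64 {
  // The source %sub is produced by a ViewLikeOpInterface op (memref.subview),
  // so linalg::hoistRedundantVectorTransfers bails out on this read/write pair.
  %acc = vector.transfer_read %sub[%c0, %c0], %pad
    : memref<32x256xbf16, strided<[256, 1], offset: ?>>, vector<32x256xbf16>
  // ... compute %upd from %acc ...
  vector.transfer_write %upd, %sub[%c0, %c0]
    : vector<32x256xbf16>, memref<32x256xbf16, strided<[256, 1], offset: ?>>
}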

arun-thmn merged commit 1742063 into libxsmm:main on Feb 17, 2025
14 checks passed
rolfmorel (Contributor) commented:

Just to document this somewhere:

brgemm-with-init-pipeline.mlir

module {
  module {
    func.func @brgemm_tpp(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
      %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<128x256x512xf32>, tensor<128x512x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
      return %0 : tensor<256x256xf32>
    }
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %1 = transform.get_parent_op %0 {deduplicate, op_name = "builtin.module"} : (!transform.any_op) -> !transform.any_op
      transform.print %1 : !transform.any_op
      %2 = transform.apply_registered_pass "fold-add-into-dest" to %1 : (!transform.any_op) -> !transform.any_op
      %3 = transform.apply_registered_pass "fold-into-eltwise" to %2 : (!transform.any_op) -> !transform.any_op
      %4 = transform.structured.match ops{["func.func"]} in %3 : (!transform.any_op) -> !transform.any_op
      %5 = transform.apply_registered_pass "convert-linalg-to-inplace" to %4 : (!transform.any_op) -> !transform.any_op
      %6 = transform.apply_registered_pass "rewrite-batch-matmul-to-matmul" to %5 : (!transform.any_op) -> !transform.any_op
      %7 = transform.structured.match ops{["func.func"]} in %3 : (!transform.any_op) -> !transform.any_op
      %8 = transform.apply_registered_pass "conv-init-simplify" to %7 : (!transform.any_op) -> !transform.any_op
      %9 = transform.apply_registered_pass "canonicalize" to %3 : (!transform.any_op) -> !transform.any_op
      transform.apply_cse to %9 : !transform.any_op
      %10 = transform.structured.match ops{["func.func"]} in %9 : (!transform.any_op) -> !transform.any_op
      %11 = transform.apply_registered_pass "pack-conv2DNchwFchw" to %10 : (!transform.any_op) -> !transform.any_op
      %12 = transform.apply_registered_pass "pack-conv2DNhwcHwcf" to %11 : (!transform.any_op) -> !transform.any_op
      %13 = transform.apply_registered_pass "rewrite-conv-to-matmul-or-brgemm" to %12 : (!transform.any_op) -> !transform.any_op
      %14 = transform.apply_registered_pass "pack-matmul" to %13 : (!transform.any_op) -> !transform.any_op
      %15 = transform.apply_registered_pass "pack-vnni" to %14 : (!transform.any_op) -> !transform.any_op
      %16 = transform.structured.match ops{["func.func"]} in %9 : (!transform.any_op) -> !transform.any_op
      %17 = transform.apply_registered_pass "propagate-pack-and-unpack" to %16 : (!transform.any_op) -> !transform.any_op
      %18 = transform.apply_registered_pass "constant-fold-pack" to %9 : (!transform.any_op) -> !transform.any_op
      %19 = transform.structured.match ops{["func.func"]} in %18 : (!transform.any_op) -> !transform.any_op
      %20 = transform.apply_registered_pass "simplify-pack" to %19 : (!transform.any_op) -> !transform.any_op
      %21 = transform.apply_registered_pass "linalg-generalize-named-ops" to %20 : (!transform.any_op) -> !transform.any_op
      %22 = transform.apply_registered_pass "canonicalize" to %18 : (!transform.any_op) -> !transform.any_op
      transform.apply_cse to %22 : !transform.any_op
      %23 = transform.structured.match ops{["func.func"]} in %22 : (!transform.any_op) -> !transform.any_op
      %24 = transform.apply_registered_pass "linalg-convert-compare-select-to-maximumf-pass" to %23 : (!transform.any_op) -> !transform.any_op
      %25 = transform.apply_registered_pass "tile-consumer-and-fuse-producers" to %24 : (!transform.any_op) -> !transform.any_op
      %26 = transform.apply_registered_pass "simplify-pack" to %25 : (!transform.any_op) -> !transform.any_op
      %27 = transform.apply_registered_pass "canonicalize" to %22 : (!transform.any_op) -> !transform.any_op
      transform.apply_cse to %27 : !transform.any_op
      %28 = transform.structured.match ops{["func.func"]} in %27 : (!transform.any_op) -> !transform.any_op
      %29 = transform.apply_registered_pass "lower-packs-unpacks" to %28 : (!transform.any_op) -> !transform.any_op
      %30 = transform.apply_registered_pass "canonicalize" to %27 : (!transform.any_op) -> !transform.any_op
      transform.apply_cse to %30 : !transform.any_op
      %31 = transform.structured.match ops{["func.func"]} in %30 : (!transform.any_op) -> !transform.any_op
      %32 = transform.apply_registered_pass "decompose-aggregated-ops" to %31 : (!transform.any_op) -> !transform.any_op
      transform.print %30 {name = "before-bufferize"} : !transform.any_op
      %33 = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %30 : (!transform.any_op) -> !transform.any_op
      %tiled_linalg_op, %loops:2 = transform.structured.tile_using_for %33 tile_sizes [0, 32, 0, 64] interchange = [0, 1, 2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
      %34 = transform.structured.match ops{["func.func"]} in %30 : (!transform.any_op) -> !transform.any_op
      %35 = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %34 : (!transform.any_op) -> !transform.any_op
      %36 = transform.get_parent_op %35 {op_name = "scf.for"} : (!transform.any_op) -> !transform.any_op
      transform.loop.hoist_loop_invariant_subsets %36 : !transform.any_op
      transform.yield 
    }
  }
}

gives

module {
  func.func @brgemm_tpp(%arg0: tensor<128x256x512xf32>, %arg1: tensor<128x512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
    %c0 = arith.constant 0 : index
    %c256 = arith.constant 256 : index
    %c512 = arith.constant 512 : index
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %0 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %arg2) -> (tensor<256x256xf32>) {
      %extracted_slice = tensor.extract_slice %arg4[%arg3, 0] [32, 256] [1, 1] : tensor<256x256xf32> to tensor<32x256xf32>
      %1 = scf.for %arg5 = %c0 to %c512 step %c64 iter_args(%arg6 = %extracted_slice) -> (tensor<32x256xf32>) {
        %extracted_slice_0 = tensor.extract_slice %arg0[0, %arg3, %arg5] [128, 32, 64] [1, 1, 1] : tensor<128x256x512xf32> to tensor<128x32x64xf32>
        %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg5, 0] [128, 64, 256] [1, 1, 1] : tensor<128x512x256xf32> to tensor<128x64x256xf32>
        %2 = linalg.batch_reduce_matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<128x32x64xf32>, tensor<128x64x256xf32>) outs(%arg6 : tensor<32x256xf32>) -> tensor<32x256xf32>
        scf.yield %2 : tensor<32x256xf32>
      }
      %inserted_slice = tensor.insert_slice %1 into %arg4[%arg3, 0] [32, 256] [1, 1] : tensor<32x256xf32> into tensor<256x256xf32>
      scf.yield %inserted_slice : tensor<256x256xf32>
    }
    return %0 : tensor<256x256xf32>
  }
}

when run with bin/tpp-opt -load-tpp-dialects -transform-interpreter -canonicalize (where the first pass is from the schedules PR).

That is to say: hoisting works on linalg with upstream tools, though the current blocker is that bufferization does not do the right thing on the IR above. So, if we are coming from linalg, we can do tiling and hoisting there; if we are not coming from linalg, we still need the appropriate transforms on vector.
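For completeness, a hypothetical continuation of the schedule above that would exercise the bufferization step where the blocker shows up; the op and attribute names are assumed from upstream MLIR's transform dialect, not taken from this repository:

// Appended after transform.loop.hoist_loop_invariant_subsets, before transform.yield.
// Assumed upstream spelling; adjust to whatever the local pipeline actually uses.
%37 = transform.bufferization.one_shot_bufferize %30
  {bufferize_function_boundaries = true}
  : (!transform.any_op) -> !transform.any_op
transform.print %37 {name = "after-bufferize"} : !transform.any_op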
