Skip to content

Commit

Permalink
Enable vector to amx code generation and execution using libxsmm plat… (
Browse files Browse the repository at this point in the history
#1006)

…form setup call
  • Loading branch information
shahidact authored Feb 18, 2025
1 parent 1742063 commit 77788a8
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 2 deletions.
4 changes: 3 additions & 1 deletion include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ class LinalgOp;

namespace vnni {
namespace utils {

// Named integer constants keyed by operation kind (TRANSPOSE, GEMM, BRGEMM).
// NOTE(review): presumably the expected rank of a VNNI-packed operand for each
// kind of operation - confirm against the users of this enum.
enum class VnniOperandRank {
TRANSPOSE = 3,
GEMM = 3,
// BRGEMM has separate entries for its inputs (4) and its outputs (3).
BRGEMM_INS = 4,
BRGEMM_OUTS = 3
};

// Returns true if the current architecture supports AMX instructions.
// The answer is derived from the libxsmm target architecture id, which must
// lie in the range [LIBXSMM_X86_AVX512_SPR, LIBXSMM_X86_ALLFEAT).
bool hasAMX();

// Return the VNNI blocking factor if it can be determined for the given type or
// zero, otherwise.
// Optionally, an operation can be provided to give access to DLTI.
Expand Down
5 changes: 4 additions & 1 deletion lib/TPP/DefaultPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "TPP/Dialect/Perf/PerfOps.h"
#include "TPP/Dialect/Xsmm/XsmmDialect.h"
#include "TPP/PassUtils.h"
#include "TPP/Transforms/Utils/VNNIUtils.h"
#include "mlir/Transforms/Passes.h"

#include <string>
Expand Down Expand Up @@ -187,7 +188,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(createPrintIRPass());

// Lower to LLVM
pm.addPass(createConvertVectorToLLVMPass());
ConvertVectorToLLVMPassOptions options;
options.amx = vnni::utils::hasAMX();
pm.addPass(createConvertVectorToLLVMPass(options));
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addPass(createConvertSCFToCFPass());
if (defParallel)
Expand Down
6 changes: 6 additions & 0 deletions lib/TPP/Transforms/Utils/VNNIUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ namespace mlir {
namespace vnni {
namespace utils {

// Reports whether the current architecture supports AMX instructions.
// The libxsmm target architecture id has to be at least
// LIBXSMM_X86_AVX512_SPR and strictly below LIBXSMM_X86_ALLFEAT.
bool hasAMX() {
  const auto archId = libxsmm_get_target_archid();
  // Anything older than the SPR target id is rejected outright.
  if (archId < LIBXSMM_X86_AVX512_SPR)
    return false;
  // The upper bound is exclusive: ALLFEAT itself does not qualify.
  return archId < LIBXSMM_X86_ALLFEAT;
}

unsigned getVnniBlockingFactor(Type type, Operation *op) {
unsigned blockingFactor = 0;

Expand Down
19 changes: 19 additions & 0 deletions test/Passes/DefaultPipeline/amx-initialization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

// RUN: LIBXSMM_TARGET=spr tpp-opt --default-pipeline %s | FileCheck %s --check-prefix=CHECK-AMX-BF16


// CHECK-AMX-BF16-LABEL: llvm.func @entry
// CHECK-AMX-BF16: amx.tileloadd64
// CHECK-AMX-BF16: amx.tdpbf16ps
// CHECK-AMX-BF16: amx.tilestored64
// Loads one 16x32 bf16 tile from each of %arg0 and %arg1, multiplies them into
// a zero-initialized 16x16 f32 accumulator tile, and stores the result to
// %arg2. All tile accesses start at offset [0, 0].
func.func @entry(%arg0: memref<16x32xbf16>,
                 %arg1: memref<16x32xbf16>,
                 %arg2: memref<16x16xf32>) {
  %c0 = arith.constant 0 : index
  %lhs = amx.tile_load %arg0[%c0, %c0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
  %rhs = amx.tile_load %arg1[%c0, %c0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
  %acc = amx.tile_zero : !amx.tile<16x16xf32>
  %res = amx.tile_mulf %lhs, %rhs, %acc : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
  amx.tile_store %arg2[%c0, %c0], %res : memref<16x16xf32>, !amx.tile<16x16xf32>
  return
}
4 changes: 4 additions & 0 deletions tools/tpp-run/tpp-run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/Target/TargetOptions.h"

#include "TPP/Transforms/Utils/TensorInit.h"
#include "libxsmm.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -270,6 +271,9 @@ int main(int argc, char **argv) {
if (failed(validateInput()))
return 1;

// Initialize the underlying platform
// TODO: Move this to use the target information flags
libxsmm_init();
// Initialize the LLVM machinery
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
Expand Down

0 comments on commit 77788a8

Please sign in to comment.