
Merge branch 'main' into copy-local-runtime
oraluben authored Feb 6, 2025
2 parents 4b324fc + 94643b2 commit 1283b0e
Showing 40 changed files with 1,676 additions and 593 deletions.
8 changes: 1 addition & 7 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -200,13 +200,7 @@ StringRef getAMDArch(Operation *module);
 std::optional<mlir::triton::gpu::SwizzledSharedEncodingAttr>
 getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
 
-enum class MMALoadType {
-  SharedV3,
-  Registers,     // may be v2 or v3
-  DoNotPipeline, // could be a valid shared/registers MMA operand, but skip
-                 // pipelining
-};
-MMALoadType getMMALoadType(Operation *loadOp);
+bool canUseMMAv3Pipelining(Operation *loadOp);
 
 // Convert \param op operands and results to layout \param encoding.
 void convertOpEncoding(Attribute encoding, Operation *op);
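
The three-way MMALoadType classification collapses into a single predicate. A minimal sketch of how a call site migrates — `getMMALoadType` and `canUseMMAv3Pipelining` are the real functions from this diff; the surrounding names are illustrative:

    // Before: callers branched on three states.
    //   MMALoadType t = getMMALoadType(loadOp);
    //   if (t == MMALoadType::DoNotPipeline) skip();   // skip() is hypothetical
    //   bool shmem = (t == MMALoadType::SharedV3);
    // After: one question remains -- can this load's result stay in shared
    // memory for MMAv3-style pipelining?
    bool mmav3Shmem = canUseMMAv3Pipelining(loadOp);
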
21 changes: 15 additions & 6 deletions include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h
@@ -50,8 +50,18 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
 
   Value elemSizeVal = builder.template create<arith::ConstantOp>(
       loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
-  Value globalStride = builder.template create<arith::MulIOp>(
-      loc, op.getStrides()[0], elemSizeVal);
+
+  SmallVector<Value> globalDim(llvm::reverse(op.getShape()));
+  SmallVector<Value> globalStride;
+  for (int k = op.getStrides().size() - 2; k >= 0; --k) {
+    globalStride.push_back(op.getStrides()[k]);
+  }
+
+  SmallVector<Value> elementStride(globalDim.size(), mkI32Constant(1));
+
+  for (int i = 0; i < globalStride.size(); ++i)
+    globalStride[i] = builder.template create<arith::MulIOp>(
+        loc, globalStride[i], elemSizeVal);
 
   int elemTypeEnum;
   switch (elemSize) {
@@ -75,15 +85,14 @@
     }
   }
 
-  auto one = mkI32Constant(1);
   builder.template create<triton::ExperimentalTensormapCreateOp>(
       loc,
       /*desc_ptr=*/tmaPtr,
       /*global_address=*/op.getBase(),
       /*box_dim=*/boxDim,
-      /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]},
-      /*global_stride=*/ValueRange{globalStride},
-      /*element_strides=*/ValueRange{one, one},
+      /*global_dim=*/globalDim,
+      /*global_stride=*/globalStride,
+      /*element_strides=*/elementStride,
       /*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum),
       /*interleave_layout*/ builder.getI32IntegerAttr(0),
       /*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode),
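
The descriptor fields were previously hard-coded for rank 2 (`op.getShape()[1], op.getShape()[0]`); they are now built for arbitrary rank. A self-contained sketch of the resulting field values for a rank-3 row-major tensor — plain integers stand in for the MLIR `Value`s, and all names are local to the example:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Row-major tensor: shape [4, 64, 32], strides in elements, 2-byte elements.
      std::vector<int64_t> shape = {4, 64, 32};
      std::vector<int64_t> strides = {64 * 32, 32, 1};
      int64_t elemSize = 2;

      // global_dim: the shape reversed, innermost dimension first.
      std::vector<int64_t> globalDim(shape.rbegin(), shape.rend()); // {32, 64, 4}

      // global_stride: strides[rank-2] down to strides[0], scaled to bytes;
      // the innermost (contiguous) stride is implicit in the TMA descriptor.
      std::vector<int64_t> globalStride;
      for (int k = (int)strides.size() - 2; k >= 0; --k)
        globalStride.push_back(strides[k] * elemSize); // {64, 4096}

      // element_strides: one entry per dimension, all 1 by default.
      std::vector<int64_t> elementStride(globalDim.size(), 1);

      std::cout << globalStride[0] << " " << globalStride[1] << "\n"; // 64 4096
    }
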
2 changes: 1 addition & 1 deletion lib/Analysis/AxisInfo.cpp
@@ -935,7 +935,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
-    return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+    return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }
 
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
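
This one-token change fixes a 32-bit overflow: the literal `1` is an `int`, so for `shift >= 31` the expression `1 << shift` is undefined behavior and in practice yields a wrong divisor, while divisibility itself is tracked as `int64_t` and can exceed 2^31. A standalone reproduction:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      // Divisibility is tracked as int64_t and can exceed 2^31.
      int64_t lhsDivisibility = int64_t(1) << 40;
      int64_t shift = 35;
      // Buggy form: `1` is a 32-bit int, so `1 << 35` is undefined behavior
      // (nothing about the result is guaranteed):
      //   int64_t bad = std::max<int64_t>(1, lhsDivisibility / (1 << shift));
      // Fixed form: promote to 64 bits before shifting.
      int64_t good =
          std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
      std::cout << good << "\n"; // prints 32 (= 2^(40-35))
    }
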
9 changes: 8 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -47,10 +47,17 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
 Type TritonGPUToLLVMTypeConverter::convertMemDescType(
     MemDescType type, const TargetInfoBase &targetInfo) {
   auto ctx = type.getContext();
-  SmallVector<Type, 4> types;
   // base ptr
   auto ptrType =
       LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
+
+  if (isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
+          triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
+          type.getEncoding())) {
+    return ptrType;
+  }
+
+  SmallVector<Type, 4> types;
   types.push_back(ptrType);
   auto rank = type.getRank();
   // offsets
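
The net effect on the lowered LLVM type, sketched as comments — a reading aid under stated assumptions, not Triton output; the exact set of extra fields depends on the memdesc rank and the rest of this function:

    // shared-memory memdesc (rank N):  struct { ptr base; /* per-dim i32 offsets, ... */ }
    // tensor-memory memdesc:           ptr base
    //
    // Tensor-memory descriptors short-circuit to a bare pointer because they
    // carry no per-dimension offset fields in this lowering.
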
6 changes: 3 additions & 3 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -1258,15 +1258,15 @@ LogicalResult ExperimentalDescriptorStoreOp::verify() {
 LogicalResult ExperimentalTensormapCreateOp::verify() {
   auto rank = getBoxDim().size();
   if (getGlobalDim().size() != rank) {
-    return emitError("Rank mismatch for global dim. Got")
+    return emitError("Rank mismatch for global dim. Got ")
            << getGlobalDim().size() << " but expected " << rank;
   }
   if (getGlobalStride().size() + 1 != rank) {
-    return emitError("Rank mismatch for global stride. Got")
+    return emitError("Rank mismatch for global stride. Got ")
            << getGlobalStride().size() << " but expected " << rank - 1;
   }
   if (getElementStride().size() != rank) {
-    return emitError("Rank mismatch for element stride. Got")
+    return emitError("Rank mismatch for element stride. Got ")
            << getElementStride().size() << " but expected " << rank;
   }
   return success();
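
The rank contract this verifier enforces, written out as a hedged sketch — plain vectors stand in for the op's operand ranges, and the function name is illustrative:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    void checkTensormapRanks(const std::vector<int> &boxDim,
                             const std::vector<int> &globalDim,
                             const std::vector<int> &globalStride,
                             const std::vector<int> &elementStride) {
      std::size_t rank = boxDim.size();         // e.g. 2 for a 2-D descriptor
      assert(globalDim.size() == rank);         // one extent per dimension
      assert(globalStride.size() == rank - 1);  // innermost stride is implicit
      assert(elementStride.size() == rank);     // one element stride per dim
    }
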
10 changes: 5 additions & 5 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -196,9 +196,10 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
   auto outDimNames = standardOutDimNames(ctx, rank);
 
   // Construct bases for a the layout's 2-dimensional tile.
-  assert(shape.size() >= 2);
-  int colDim = shared.getTransposed() ? 0 : 1;
-  int rowDim = shared.getTransposed() ? 1 : 0;
+  assert(rank >= 2);
+  int batchDims = rank - 2;
+  int colDim = batchDims + (shared.getTransposed() ? 0 : 1);
+  int rowDim = batchDims + (shared.getTransposed() ? 1 : 0);
 
   int tileRows = 8;
   int tileCols = 8 * tileWidthBytes / elemBitWidth;
@@ -254,8 +255,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
       LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName});
 
   // Add the remaining dimensions.
-  for (int i = 2; i < rank; i++) {
-    int dim = shared.getTransposed() ? i : 1 - i;
+  for (int dim = batchDims - 1; dim >= 0; --dim) {
     tileLayout *=
         LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]);
   }
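
A worked instance of the new index arithmetic, assuming a rank-3, non-transposed shape (illustrative values only):

    #include <iostream>

    int main() {
      // shape = [4, 64, 32], shared.getTransposed() == false
      int rank = 3;
      int batchDims = rank - 2;                       // 1 leading batch dim
      bool transposed = false;
      int colDim = batchDims + (transposed ? 0 : 1);  // 2: the contiguous dim
      int rowDim = batchDims + (transposed ? 1 : 0);  // 1
      std::cout << colDim << " " << rowDim << "\n";   // 2 1
      // The swizzled 2-D tile covers dims {rowDim, colDim}; the rewritten
      // loop then multiplies in identity layouts for dims batchDims-1 down
      // to 0 (here just dim 0), replacing the old `1 - i` indexing.
    }
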
14 changes: 0 additions & 14 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -338,20 +338,6 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
     bool bFromLoad = comesFromLoadOrBlockArg(dotOp.getB());
     bool transpose = false;
     auto origDotOp = dotOp;
-    if (aFromLoad && !bFromLoad) {
-      // If the lhs is not a load and the rhs is, we transpose the inputs
-      // and the result provided this allows us to use mmav3
-      // We transpose the result at the end of the rewrite
-      DotOp transDot = transposeDotOp(rewriter, dotOp);
-      if (getMMAVersionSafe(computeCapability, transDot) == 3) {
-        dotOp = transDot;
-        versionMajor = 3;
-        transpose = true;
-      }
-      std::swap(aFromLoad, bFromLoad);
-    }
-    // If !aFromLoad && !bFromLoad, we just accept a shmem roundtrip
-    // for versionMajor == 3
 
     Value a = dotOp.getA();
     Value b = dotOp.getB();
@@ -87,11 +87,6 @@ bool isPipeliningBeneficial(Operation *op, Operation *finalUser,
   if (isa<tt::ExperimentalDescriptorLoadOp, tt::ExperimentalDescriptorGatherOp>(
           op))
     return true;
-  if (isa<ttng::WarpGroupDotOp>(finalUser) &&
-      getMMALoadType(op) == MMALoadType::DoNotPipeline) {
-    LDBG("Load " << *op << " used by WarpGroupDotOp with incompatible layout");
-    return false;
-  }
   if (!canHaveSharedEncoding(cast<tt::LoadOp>(op))) {
     LDBG("Load " << *op << " cannot have shared encoding");
     return false;
21 changes: 12 additions & 9 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -41,7 +41,6 @@ struct LoadInfo {
   // Blocked encoding is used for loads not used by the dot.
   ttg::BlockedEncodingAttr blockedEncoding = nullptr;
   bool isMMAv3Shared = false;
-  bool isMMAv3Registers = false;
   bool isMMAv5Scale = false;
   int distToUse = 0;
   bool usedByDot = false;
@@ -518,21 +517,23 @@ assignMemoryLayouts(scf::ForOp &forOp,
     loadsToPipeline.insert(&op);
     LoadInfo loadInfo;
     for (auto use : users) {
+      // By default we will try pipelining with load to registers at the end.
+      // For mmav3 we can try leaving the operands in shared memory.
+      bool mmav3Shmem = false;
       if (isa<mlir::triton::DotOpInterface>(use)) {
         LDBG("set shared encoding with dot user: " << *use);
-        auto mmaLoadType = getMMALoadType(&op);
         auto dot = dyn_cast<tt::DotOp>(use);
-        auto warpGroupDot = dyn_cast<ttng::WarpGroupDotOp>(use);
+        bool isMMAv3v5Dot = isa<ttng::WarpGroupDotOp, ttng::TCGen5MMAOp,
+                                ttng::TCGen5MMAScaledOp>(use);
+        mmav3Shmem = canUseMMAv3Pipelining(&op) && isMMAv3v5Dot;
 
         loadInfo.usedByDot = true;
-        loadInfo.isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3;
-        loadInfo.isMMAv3Registers =
-            (mmaLoadType == MMALoadType::Registers) && warpGroupDot;
+        loadInfo.isMMAv3Shared = mmav3Shmem;
 
-        if (loadInfo.isMMAv3Shared || isTMALoad) {
+        if (mmav3Shmem || isTMALoad) {
           loadInfo.sharedEncoding =
               getSharedEncoding(&op, isTMALoad).value_or(nullptr);
-        } else if (loadInfo.isMMAv3Registers || dot) {
+        } else if (!mmav3Shmem || dot) {
           bool incompatible = false;
 
           loadInfo.sharedEncoding =
@@ -543,7 +544,9 @@
 
           // If we still don't have a shared encoding, try a "generic" shared
           // encoding.
-          if (!loadInfo.sharedEncoding && !isa<ttng::WarpGroupDotOp>(use)) {
+          if (!loadInfo.sharedEncoding) {
+            assert(!loadInfo.isMMAv3Shared &&
+                   "For MMAv3 pipelining we should have shared encoding");
             LDBG("try generic shared encoding");
             loadInfo.sharedEncoding =
                 getSharedEncoding(&op, isTMALoad).value_or(nullptr);
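
Putting the new branch structure in one place — a hedged summary of what `assignMemoryLayouts` now decides for each load feeding a dot, written as comments (the function names are from this diff; the fallback step comes from the hunk above):

    // mmav3Shmem = canUseMMAv3Pipelining(load) &&
    //              (user is WarpGroupDot / TCGen5MMA / TCGen5MMAScaled)
    //
    // mmav3Shmem || isTMALoad -> keep the operand in shared memory
    //                            (getSharedEncoding)
    // otherwise               -> try the dot-operand shared encoding; if none
    //                            is found, fall back to a generic shared
    //                            encoding (asserting the load was not marked
    //                            isMMAv3Shared)
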
33 changes: 13 additions & 20 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -1045,16 +1045,22 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
   return attr;
 }
 
-MMALoadType getMMALoadType(Operation *loadOp) {
-  if (!loadOp->hasOneUse())
-    return MMALoadType::DoNotPipeline;
+bool canUseMMAv3Pipelining(Operation *loadOp) {
+  Operation *user = *loadOp->getUsers().begin();
+  while (isa<triton::TransOp, triton::ReshapeOp>(user)) {
+    if (!user->hasOneUse())
+      return false;
+    user = *user->getUsers().begin();
+  }
+  if (!user)
+    return false;
 
-  if (auto alloc = dyn_cast<ttg::LocalAllocOp>(*loadOp->getUsers().begin())) {
+  if (auto alloc = dyn_cast<ttg::LocalAllocOp>(user)) {
     auto sharedEnc =
         dyn_cast<ttg::NVMMASharedEncodingAttr>(alloc.getType().getEncoding());
 
     if (!sharedEnc)
-      return MMALoadType::DoNotPipeline;
+      return false;
 
     // MMA V3 case.
     SmallVector<unsigned> newOrder = getOrder(sharedEnc);
@@ -1065,22 +1071,9 @@
     // be changed after FuseTranspositions Pass. So we only pipeline the
     // load if the order of the loaded BlockedEncoding is the same as the
    // order of the SharedEncoding it is converted to.
-    return oldOrder == newOrder ? MMALoadType::SharedV3
-                                : MMALoadType::DoNotPipeline;
-  } else if (auto cvt =
-                 dyn_cast<ttg::ConvertLayoutOp>(*loadOp->getUsers().begin())) {
-    auto resTy = dyn_cast<RankedTensorType>(cvt->getResultTypes()[0]);
-    if (!resTy) {
-      return MMALoadType::DoNotPipeline;
-    }
-
-    if (isa<ttg::DotOperandEncodingAttr>(resTy.getEncoding())) {
-      return MMALoadType::Registers;
-    }
-
-    return MMALoadType::DoNotPipeline;
+    return oldOrder == newOrder;
   } else {
-    return MMALoadType::DoNotPipeline;
+    return false;
  }
 }
 
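
One behavioral difference worth illustrating: the predicate now walks through single-use `tt.trans`/`tt.reshape` ops between the load and the `local_alloc`. A hedged, IR-shaped sketch (written as comments, operand types elided) of a chain that the old `getMMALoadType` rejected outright — its single user was not a `local_alloc` — but that can now qualify:

    // %l = tt.load %ptr ...       // single use
    // %t = tt.trans %l ...        // single use, skipped by the walk
    // %a = ttg.local_alloc %t     // NVMMA shared encoding
    //
    // canUseMMAv3Pipelining on %l's defining op returns true, provided the
    // order of the loaded blocked encoding matches the order of the shared
    // encoding it is converted to.
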