
Commit

Merge remote-tracking branch 'origin/main' into ravil/sched-hint-refactor
antiagainst committed Feb 5, 2025
2 parents c5483bd + 87187d1 commit 5ffbdaf
Showing 76 changed files with 2,833 additions and 1,297 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/wheels_v2.yml
@@ -22,7 +22,7 @@ jobs:
docker container prune -f
- name: Checkout
-    uses: actions/checkout@v3
+    uses: actions/checkout@v4

# The LATEST_DATE here should be kept in sync with the one in Patch setup.py
- id: check-version
Expand Down
2 changes: 1 addition & 1 deletion Makefile
@@ -47,7 +47,7 @@ test-regression: all
test-interpret: all
cd python/test/unit && TRITON_INTERPRET=1 $(PYTEST) -s -n 16 -m interpreter cuda language/test_core.py language/test_standard.py \
language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \
-runtime/test_autotuner.py::test_kwargs[False] \
+language/test_tuple.py runtime/test_autotuner.py::test_kwargs[False] \
../../tutorials/06-fused-attention.py::test_op --device=cpu

.PHONY: test-proton
4 changes: 2 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -427,8 +427,8 @@ class SharedMemoryObject {
SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
RewriterBase &rewriter) const {
auto allocShape = memDesc.getAllocShape();
-auto allocShapePerCTA =
-    triton::gpu::getShapePerCTA(memDesc.getEncoding(), allocShape);
+auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA(
+    memDesc.getEncoding(), allocShape);
auto layoutOrder = triton::gpu::getOrder(memDesc.getEncoding());
auto allocStrides = SharedMemoryObject::getStridesForShape(
allocShapePerCTA, layoutOrder, loc, rewriter);
2 changes: 2 additions & 0 deletions include/triton/Dialect/Triton/IR/OpInterfaces.h
@@ -11,6 +11,8 @@ namespace impl {

LogicalResult verifyTransposeOpInterface(Operation *op);

LogicalResult verifyDotOpInterface(Operation *op);

} // namespace impl

} // namespace triton
47 changes: 0 additions & 47 deletions include/triton/Dialect/Triton/IR/Traits.h
@@ -58,53 +58,6 @@ class VerifyTensorLayoutsTrait
}
};

// Verify if the op is a dot-like operation.
// A dot-like operation should have three operands.
// The first two operands should share a common dimension, and the result
// should have the dimensions of the two operands that are not shared.
// A dot-like operation can be either 2d or 3d.
// In the 3d case, the first dimension of operands is the batch dimension.
template <class ConcreteType>
class DotLike : public TraitBase<ConcreteType, DotLike> {
public:
static LogicalResult verifyTrait(Operation *op) {
if (op->getNumOperands() < 3)
return op->emitOpError("expected at least 3 operands");
auto aTy = cast<ShapedType>(op->getOperand(0).getType());
auto bTy = cast<ShapedType>(op->getOperand(1).getType());
auto cTy = cast<ShapedType>(op->getOperand(2).getType());
auto aShape = aTy.getShape();
auto bShape = bTy.getShape();
auto cShape = cTy.getShape();
// Check if all 3d or all 2d
if (aShape.size() != 2 && aShape.size() != 3)
return op->emitOpError("expected operands to be 2d or 3d");
if (aShape.size() != bShape.size() || aShape.size() != cShape.size())
return op->emitOpError("expected all operands to have the same rank");
// Check if the first two operands share a common dimension
// TODO: enable back with an interface to support scaled dot.
// if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2])
// return op->emitOpError("expected the last dimension of the first
// operand "
// "to be equal to the second-to-last dimension of
// " "the second operand");
// Check the batch dimension
if (aShape.size() == 3 &&
(aShape[0] != cShape[0] || bShape[0] != cShape[0]))
return op->emitOpError("expected the first dimension of the first "
"operand to be equal to the first dimension of "
"the result");
// Check the output shape
if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] ||
cShape[cShape.size() - 1] != bShape[aShape.size() - 1])
return op->emitOpError(
"expected the output shape to be the concatenation of the last "
"dimension of the first operand and the last dimension of the "
"second ");
return success();
}
};

template <typename ConcreteType>
class SameOperandsAndResultEncoding
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
1 change: 0 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonInterfaces.td
@@ -6,7 +6,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"

def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
def DotLike : NativeOpTrait<"DotLike">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;
22 changes: 21 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOpInterfaces.td
@@ -29,7 +29,27 @@ def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
/*args=*/(ins)>
];

let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }];
}

def DotOpInterface : OpInterface<"DotOpInterface"> {
let description = [{
This interface is implemented by operations that perform a dot product.
}];

let cppNamespace = "::mlir::triton";

let methods = [
InterfaceMethod<
/*desc=*/[{
Verifies the dimensions of the A and B DotOp operands.
}],
/*retType=*/"bool",
/*methodName=*/"verifyDims",
/*args=*/(ins)>
];

let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }];
}


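A minimal sketch of what an implementation of the verifyDims method declared above could look like for a plain 2-D dot. The op name MyDotOp and the getA/getB accessors are illustrative assumptions; only the method name and its bool return type come from the interface definition.

// Hypothetical DotOpInterface implementation for a 2-D MxK * KxN dot.
// MyDotOp, getA, and getB are illustrative names, not part of this commit.
bool MyDotOp::verifyDims() {
  auto aShape = llvm::cast<mlir::ShapedType>(getA().getType()).getShape();
  auto bShape = llvm::cast<mlir::ShapedType>(getB().getType()).getShape();
  // The shared K dimension: last dim of A must match second-to-last dim of B.
  return aShape[aShape.size() - 1] == bShape[bShape.size() - 2];
}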
4 changes: 2 additions & 2 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -631,7 +631,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
//
def TT_DotOp : TT_Op<"dot", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
-DotLike,
+DeclareOpInterfaceMethods<DotOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";
@@ -671,7 +671,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
//
def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
AttrSizedOperandSegments,
-DotLike,
+DeclareOpInterfaceMethods<DotOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot_scaled";
8 changes: 8 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -165,11 +165,19 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
*/
SmallVector<unsigned> getShapePerCTATile(Attribute layout);

// Returns the "logical" shape per CTA
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Type type);

// Returns the shape per CTA, which is "physically" allocated
// Such shapes may be bigger than the logical one due to, for example, padding
// in shared memory.
SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
ArrayRef<int64_t> shape);
SmallVector<int64_t> getAllocationShapePerCTA(Type type);

unsigned getNumWarpsPerCTA(Attribute layout);

unsigned getNumCTAs(Attribute layout);
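A rough usage sketch of the distinction: buffer sizing should use the allocation shape rather than the logical shape, so padded shared-memory layouts reserve enough bytes. This mirrors the Allocation.cpp hunk later in this commit; the surrounding variable names are illustrative.

// Size a shared-memory buffer from the physically allocated shape per CTA.
// `alloc` stands for some memdesc-typed allocation value (illustrative).
auto allocType = alloc.getType();
auto shapePerCTA = triton::gpu::getAllocationShapePerCTA(allocType);
int64_t bytes =
    product<int64_t>(shapePerCTA) * allocType.getElementTypeBitWidth() / 8;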
12 changes: 9 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -419,26 +419,32 @@ def NVMMASharedEncodingAttr :
https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout
}];


// fp4Padded: Indicates that this encoding represents a mixed-precision fp4 operand in MMAv5 scaled dot, which needs
// to be in the special padded layout as described in https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
let parameters = (
ins
"unsigned":$swizzlingByteWidth,
"bool":$transposed,
"unsigned":$elementBitWidth,
"bool":$fp4Padded,
"CTALayoutAttr":$CTALayout
);

let builders = [
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"Type":$eltTy), [{
"Type":$eltTy,
"bool": $fp4Padded), [{
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
int32_t swizzlingByteWidth = 0;
unsigned eleBitWidth = eltTy.getIntOrFloatBitWidth();
int packingFactor = fp4Padded ? 2 : 1;

// get proper shared memory swizzling mode from the contiguous dimension
// size of the origin blocked layout.
-auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8;
+auto contigDimSizeInByte = shapePerCTA[order[0]] * packingFactor * eleBitWidth / 8;
if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) {
swizzlingByteWidth = 128;
} else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) {
@@ -449,7 +455,7 @@ def NVMMASharedEncodingAttr :
llvm_unreachable("unsupported shared memory layout for MMAv3");
}
bool transposed = order[0] == 0;
-return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, CTALayout);
+return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, fp4Padded, CTALayout);
}]>
];

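A standalone sketch of the swizzle-width selection performed by the builder above, isolating how the fp4 packing factor doubles the contiguous-dimension byte count. The helper name and the fallback value are illustrative; the real builder handles further cases and hits llvm_unreachable for unsupported layouts.

// Illustrative helper: pick a swizzling byte width from the contiguous
// dimension size; fp4Padded operands occupy twice the bytes per element.
static unsigned pickSwizzlingByteWidth(int64_t contigDimSize,
                                       unsigned eleBitWidth, bool fp4Padded) {
  int packingFactor = fp4Padded ? 2 : 1;
  int64_t contigDimSizeInByte = contigDimSize * packingFactor * eleBitWidth / 8;
  if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0)
    return 128;
  if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0)
    return 64;
  // Smaller widths / unsupported cases elided, as in the diff above.
  return 0;
}

For example, assuming a 64-element contiguous dimension of 8-bit packed fp4 data with fp4Padded set, 64 * 2 * 8 / 8 = 128 bytes, so the 128-byte swizzle mode is chosen.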
8 changes: 1 addition & 7 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -200,13 +200,7 @@ StringRef getAMDArch(Operation *module);
std::optional<mlir::triton::gpu::SwizzledSharedEncodingAttr>
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);

-enum class MMALoadType {
-  SharedV3,
-  Registers, // may be v2 or v3
-  DoNotPipeline, // could be a valid shared/registers MMA operand, but skip
-                 // pipelining
-};
-MMALoadType getMMALoadType(Operation *loadOp);
+bool canUseMMAv3Pipelining(Operation *loadOp);

// Convert \param op operands and results to layout \param encoding.
void convertOpEncoding(Attribute encoding, Operation *op);
@@ -27,6 +27,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
include "mlir/IR/OpBase.td"
@@ -71,7 +72,7 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
//
def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-DotLike,
+DeclareOpInterfaceMethods<DotOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "warp group dot";
@@ -325,7 +326,7 @@ def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
let assemblyFormat = "attr-dict";
}

-def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike]> {
+def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>]> {
let summary = "block level op mapping to tensorcore gen5 mma";

let description = [{
@@ -343,11 +344,12 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryE
I1:$pred,
Optional<TTG_MemDescType>:$barrier,
OptionalAttr<UnitAttr>:$two_ctas);

// TODO: improve printing format.
let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
}

-def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike]> {
+def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>]> {
let summary = "block level op mapping to tensorcore gen5 mma";

let description = [{
@@ -366,6 +368,7 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMe
I1:$useD,
I1:$pred,
Optional<TTG_MemDescType>:$barrier);

// TODO: improve printing format.
let assemblyFormat = "$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
}
21 changes: 15 additions & 6 deletions include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h
@@ -50,8 +50,18 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,

Value elemSizeVal = builder.template create<arith::ConstantOp>(
loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
-Value globalStride = builder.template create<arith::MulIOp>(
-    loc, op.getStrides()[0], elemSizeVal);

+SmallVector<Value> globalDim(llvm::reverse(op.getShape()));
+SmallVector<Value> globalStride;
+for (int k = op.getStrides().size() - 2; k >= 0; --k) {
+  globalStride.push_back(op.getStrides()[k]);
+}

+SmallVector<Value> elementStride(globalDim.size(), mkI32Constant(1));

+for (int i = 0; i < globalStride.size(); ++i)
+  globalStride[i] = builder.template create<arith::MulIOp>(
+      loc, globalStride[i], elemSizeVal);

int elemTypeEnum;
switch (elemSize) {
@@ -75,15 +85,14 @@
}
}

-auto one = mkI32Constant(1);
builder.template create<triton::ExperimentalTensormapCreateOp>(
loc,
/*desc_ptr=*/tmaPtr,
/*global_address=*/op.getBase(),
/*box_dim=*/boxDim,
-/*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]},
-/*global_stride=*/ValueRange{globalStride},
-/*element_strides=*/ValueRange{one, one},
+/*global_dim=*/globalDim,
+/*global_stride=*/globalStride,
+/*element_strides=*/elementStride,
/*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum),
/*interleave_layout*/ builder.getI32IntegerAttr(0),
/*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode),
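For intuition, here is a hedged, self-contained rewrite of the dim/stride reshuffling above using plain integers instead of MLIR Values, for a hypothetical row-major 2-D tensor of shape [M, N] = [128, 64] with element strides [64, 1] and 2-byte elements. TMA descriptors take dimensions from the innermost outward, which is why the shape and strides are reversed and the innermost unit stride is dropped.

#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> shape   = {128, 64}; // [M, N]
  std::vector<int64_t> strides = {64, 1};   // row-major, in elements
  int64_t elemSize = 2;                     // bytes per element (assumed)

  // globalDim: shape reversed -> {64, 128} (innermost dimension first).
  std::vector<int64_t> globalDim(shape.rbegin(), shape.rend());

  // globalStride: strides reversed, innermost unit stride dropped,
  // then scaled to bytes -> {64 * 2} = {128}.
  std::vector<int64_t> globalStride;
  for (int k = (int)strides.size() - 2; k >= 0; --k)
    globalStride.push_back(strides[k] * elemSize);

  // elementStride: one entry per dimension, all ones -> {1, 1}.
  std::vector<int64_t> elementStride(globalDim.size(), 1);
  return 0;
}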
5 changes: 2 additions & 3 deletions include/triton/Tools/LinearLayout.h
@@ -683,9 +683,8 @@ class LinearLayout {
// Otherwise, R could map some tensor index that is not stored in S.
//
// One requirement we *don't* have is that S is injective; we allow two shmem
-// offsets to hold the same 2D index. If S is not injective, there's
-// ambiguity in which offset we choose for a given (lane, warp). For now we
-// don't place any guarantees on the choices made by this function.
+// offsets to hold the same 2D index. If S is not injective,
+// the algorithm chooses the smallest offset for a given (lane, warp).
[[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const;

// Get the layout that is the inverse of this layout.
2 changes: 1 addition & 1 deletion lib/Analysis/Allocation.cpp
@@ -234,7 +234,7 @@ class AllocationAnalysis {
// Bytes could be a different value once we support padding or other
// allocation policies.
auto allocType = alloc.getType();
-auto shapePerCTA = gpu::getShapePerCTA(allocType);
+auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType);
auto bytes = product<int64_t>(shapePerCTA) *
allocType.getElementTypeBitWidth() / 8;

2 changes: 1 addition & 1 deletion lib/Analysis/AxisInfo.cpp
@@ -935,7 +935,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
lhsDivisibility = 1;
}
-return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
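The small change above avoids a 32-bit overflow: the literal 1 is an int, so 1 << shift is undefined once shift reaches 31, even though the surrounding arithmetic is int64_t. A minimal standalone illustration of the fixed form:

#include <cstdint>
#include <iostream>

int main() {
  int64_t lhsDivisibility = int64_t(1) << 40;
  int64_t shift = 35;
  // Promoting the literal to int64_t keeps the shift well-defined.
  int64_t result = lhsDivisibility / (int64_t(1) << shift); // == 32
  std::cout << result << "\n";
  // With a plain `1 << 35`, the shift would overflow int (undefined
  // behavior), which is what the diff above fixes.
  return 0;
}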
9 changes: 8 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -47,10 +47,17 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
Type TritonGPUToLLVMTypeConverter::convertMemDescType(
MemDescType type, const TargetInfoBase &targetInfo) {
auto ctx = type.getContext();
-SmallVector<Type, 4> types;
// base ptr
auto ptrType =
LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());

if (isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
type.getEncoding())) {
return ptrType;
}

+SmallVector<Type, 4> types;
types.push_back(ptrType);
auto rank = type.getRank();
// offsets