Skip to content

Commit

Permalink
Fix computeOrigBitWidth
Browse files Browse the repository at this point in the history
  • Loading branch information
lezcano committed Feb 10, 2025
1 parent 8ab14ad commit 382c3aa
Showing 1 changed file with 8 additions and 23 deletions.
31 changes: 8 additions & 23 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,21 +221,10 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
static bool bwdFilter(Operation *op) {
return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
isa<BroadcastOp, ExpandDimsOp, ReshapeOp, TransOp, Fp4ToFpOp,
ConvertLayoutOp>(op);
ConvertLayoutOp, DotOp>(op);
}

// Finds the first different bitwidth in the chain of shape-preserving
// unary ops that x depends on.
// There are two primary scenarios:
// (1) Upcasting: A sequence such as loading an fp16, followed by arithmetic
// operations, then bitcasting to fp32, and finally computing in fp32.
// (2) Downcasting: This might involve loading an fp32, performing arithmetic
// operations, bitcasting to fp16, and finally computing in fp16.
// In the upcasting scenario, element reordering converts the original
// elements distribution to the order of higher precision primitives. As a
// result, kwidth can be the bitwidth of the lower precision primitive.
// Conversely, in the downcasting scenario, no reordering is performed,
// making it directly use the lower precision primitive.
// Finds the bitwidth with which the value x is loaded
static int computeOrigBitWidth(Value x) {
SetVector<Operation *> slice;
mlir::BackwardSliceOptions opt;
Expand All @@ -249,17 +238,13 @@ static int computeOrigBitWidth(Value x) {
if (llvm::any_of(slice, [](Operation *op) { return isa<Fp4ToFpOp>(op); }))
return 4;

int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
int origBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
for (auto op : slice) {
if (Value arg = op->getOperand(0))
if (auto argTy = dyn_cast<RankedTensorType>(arg.getType())) {
auto argBitWidth = argTy.getElementType().getIntOrFloatBitWidth();
if (argBitWidth != origBitWidth) {
origBitWidth = std::min<int>(origBitWidth, argBitWidth);
break;
}
}
if (isa<LoadOp, ExperimentalDescriptorLoadOp>(op)) {
origBitWidth = std::min<int>(
origBitWidth, cast<RankedTensorType>(op->getResult(0).getType())
.getElementTypeBitWidth());
}
}
return origBitWidth;
}
Expand Down

0 comments on commit 382c3aa

Please sign in to comment.