From d790ba6c352abd797e0b59575a1dc6ce568f7b23 Mon Sep 17 00:00:00 2001 From: Razvan Lupusoru Date: Thu, 30 Jan 2025 15:19:11 -0800 Subject: [PATCH] [mlir][acc] Update LegalizeDataValues pass to allow MappableType With the addition of new type interface MappableType, the LegalizeDataValues should not make the assumption it can obtain a pointer to the data (aka acc::getVarPtr() is now not guaranteed to get a value - acc::getVar() must be used instead). Thus update the pass to ensure it handles any var used in its data clause operations. --- .../OpenACC/Transforms/LegalizeDataValues.cpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp index 026b309ce4969..a553653c73479 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp @@ -36,17 +36,17 @@ static bool insideAccComputeRegion(mlir::Operation *op) { return false; } -static void collectPtrs(mlir::ValueRange operands, +static void collectVars(mlir::ValueRange operands, llvm::SmallVector> &values, bool hostToDevice) { for (auto operand : operands) { - Value varPtr = acc::getVarPtr(operand.getDefiningOp()); - Value accPtr = acc::getAccPtr(operand.getDefiningOp()); - if (varPtr && accPtr) { + Value var = acc::getVar(operand.getDefiningOp()); + Value accVar = acc::getAccVar(operand.getDefiningOp()); + if (var && accVar) { if (hostToDevice) - values.push_back({varPtr, accPtr}); + values.push_back({var, accVar}); else - values.push_back({accPtr, varPtr}); + values.push_back({accVar, var}); } } } @@ -75,16 +75,16 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { llvm::SmallVector> values; if constexpr (std::is_same_v) { - collectPtrs(op.getReductionOperands(), values, hostToDevice); - collectPtrs(op.getPrivateOperands(), values, hostToDevice); + collectVars(op.getReductionOperands(), values, hostToDevice); + collectVars(op.getPrivateOperands(), values, hostToDevice); } else { - collectPtrs(op.getDataClauseOperands(), values, hostToDevice); + collectVars(op.getDataClauseOperands(), values, hostToDevice); if constexpr (!std::is_same_v && !std::is_same_v && !std::is_same_v) { - collectPtrs(op.getReductionOperands(), values, hostToDevice); - collectPtrs(op.getPrivateOperands(), values, hostToDevice); - collectPtrs(op.getFirstprivateOperands(), values, hostToDevice); + collectVars(op.getReductionOperands(), values, hostToDevice); + collectVars(op.getPrivateOperands(), values, hostToDevice); + collectVars(op.getFirstprivateOperands(), values, hostToDevice); } }