Add Scalarization Patterns for AtenToDtypeOp, AtenNegOp, AtenRemainderTensorOp #3861

Merged
merged 11 commits on Nov 12, 2024
26 changes: 24 additions & 2 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -82,6 +82,25 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
};
} // namespace

namespace {
class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
public:
using OpConversionPattern<AtenNegIntOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenNegIntOp op,
typename OpConversionPattern<AtenNegIntOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value a = adaptor.getA();
rewriter.replaceOpWithNewOp<arith::SubIOp>(
op,
rewriter.create<arith::ConstantIntOp>(op.getLoc(), /*value=*/0,
/*bitwidth=*/64),
a);
return success();
}
};
} // namespace

namespace {
template <typename AtenOp, typename UnaryOp>
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
@@ -465,11 +484,14 @@ class ConvertTorchToArith

target.addIllegalOp<AtenAddOp>();
patterns.add<ConvertAtenAddOp>(typeConverter, context);

target.addIllegalOp<AtenNegIntOp>();
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
AtenMulIntOp, AtenRemainderIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
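
Note: the two new conversions in this file map Torch scalar ops directly onto arith ops: aten.neg.int becomes a subtraction from a 64-bit zero constant, and aten.remainder.int reuses the generic ConvertAtenBinaryOp pattern with arith.remsi. A minimal standalone C++ sketch of the scalar arithmetic these lowerings compute (illustrative helper names, not part of the patch):

#include <cstdint>

// neg.int is emitted as subi(0, a).
int64_t negInt(int64_t a) { return int64_t{0} - a; }

// remainder.int is emitted as remsi, a signed remainder whose result takes
// the sign of the dividend, i.e. the same behavior as C++ '%'.
int64_t remainderInt(int64_t a, int64_t b) { return a % b; }
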
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4068,6 +4068,10 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
if (lConstant && lhs == 1)
return getOperand(1);
if (rConstant && rhs == 1)
return getOperand(0);
if ((lConstant && lhs == 0) || (rConstant && rhs == 0))
return getI64IntegerAttr(getContext(), 0);
if (lConstant && rConstant)
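
Note: the fold change above adds multiplicative-identity cases (1 * x and x * 1 fold to the other operand) ahead of the existing zero and constant-constant folds. A rough standalone sketch of the resulting fold order, using a hypothetical helper with optionals in place of matched constant attributes (not the real OpFoldResult API):

#include <cstdint>
#include <optional>

enum class MulFold { ReturnRhs, ReturnLhs, Zero, Product, None };

// Decision order of AtenMulIntOp::fold after this change.
MulFold foldMulInt(std::optional<int64_t> lhs, std::optional<int64_t> rhs) {
  if (lhs && *lhs == 1)
    return MulFold::ReturnRhs; // 1 * x -> x
  if (rhs && *rhs == 1)
    return MulFold::ReturnLhs; // x * 1 -> x
  if ((lhs && *lhs == 0) || (rhs && *rhs == 0))
    return MulFold::Zero; // anything * 0 -> 0
  if (lhs && rhs)
    return MulFold::Product; // both constant: fold to lhs * rhs
  return MulFold::None; // no fold applies
}
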
5 changes: 5 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4587,6 +4587,11 @@ class DecomposeAtenUnflattenIntOp
if (!isValidDim(dimInt, inputRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");

if (inputShape[dimInt] == Torch::kUnknownSize &&
llvm::count(sizesInts, -1) > 0)
return rewriter.notifyMatchFailure(
op, "Unimplemented: dynamic unflatten dim with an inferred size.");

SmallVector<Value> sizesTorchInt;
if (!getListConstructElements(op.getSizes(), sizesTorchInt))
return rewriter.notifyMatchFailure(
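
Note: the new guard skips decomposing aten.unflatten.int when the unflattened dimension has unknown size and the requested sizes contain a -1 entry, since inferring the -1 requires dividing a known extent by the product of the other sizes. A small sketch of that inference under the assumption of a statically known extent (illustrative only, not part of the patch):

#include <cstdint>
#include <vector>

// Infer the single -1 entry of an unflatten size list from a known extent.
// With a dynamic extent this division cannot be computed at compile time,
// which is why the decomposition now bails out in that case.
int64_t inferWildcardSize(int64_t extent, const std::vector<int64_t> &sizes) {
  int64_t known = 1;
  for (int64_t s : sizes)
    if (s != -1)
      known *= s;
  return extent / known;
}
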
228 changes: 216 additions & 12 deletions lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

// Rank 0 item op prop
if (selfTy.getSizes().empty()) {
auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
if (!squeezeDim && !numToTensor)
@@ -746,6 +746,109 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
};
} // namespace

namespace {

LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b,
SmallVector<OpFoldResult> &converted,
SmallVector<OpFoldResult> &elements,
Type inputDtype, Type resultDtype) {
auto inputIsInt = dyn_cast<mlir::IntegerType>(inputDtype);
auto resultIsInt = dyn_cast<mlir::IntegerType>(resultDtype);
if (!inputIsInt && !isa<mlir::FloatType>(inputDtype))
return failure();
if (!resultIsInt && !isa<mlir::FloatType>(resultDtype))
return failure();

// if dtypes are both int or both float, no conversion needed
if (static_cast<bool>(inputIsInt) == static_cast<bool>(resultIsInt)) {
converted = elements;
return success();
}

if (resultIsInt) {
for (auto &e : elements) {
auto eValue = dyn_cast<Value>(e);
if (eValue) {
converted.push_back(b.createOrFold<AtenIntScalarOp>(eValue));
continue;
}
auto eAttr = dyn_cast<Attribute>(e);
auto eFloatAttr = dyn_cast_or_null<FloatAttr>(eAttr);
if (!eFloatAttr)
return failure();

converted.push_back(IntegerAttr::get(
resultDtype, static_cast<int64_t>(eFloatAttr.getValueAsDouble())));
}
return success();
}

// result is float
for (auto &e : elements) {
auto eValue = dyn_cast<Value>(e);
if (eValue) {
converted.push_back(b.createOrFold<AtenFloatScalarOp>(eValue));
continue;
}
auto eAttr = dyn_cast<Attribute>(e);
auto eIntAttr = dyn_cast<IntegerAttr>(eAttr);
if (!eIntAttr)
return failure();

auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue()
: eIntAttr.getValue().getZExtValue();
converted.push_back(FloatAttr::get(resultDtype, static_cast<double>(eInt)));
}
return success();
}

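// Propagates aten.to.dtype over scalarized shape values when both the source
// and result dtypes are statically known integer or float types. The
// element-wise conversion is done by convertOpFoldResults above: constant
// float attributes are truncated toward zero via static_cast<int64_t>,
// constant integer attributes are rebuilt as float attributes of the result
// dtype, and non-constant elements are converted with AtenIntScalarOp /
// AtenFloatScalarOp.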
class PropagateAtenToDtypePattern : public OpRewritePattern<AtenToDtypeOp> {
public:
using OpRewritePattern<AtenToDtypeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenToDtypeOp op,
PatternRewriter &rewriter) const override {
bool nonBlocking, copyArg;
// The non_blocking arg must be `False`.
if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
nonBlocking)
return failure();
// The copy arg must be `False`.
if (!matchPattern(op.getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
return failure();
// The memory_format arg must be `none`.
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()))
return failure();

auto inputType = dyn_cast<ValueTensorType>(op.getSelf().getType());
auto resultType = dyn_cast<ValueTensorType>(op.getType());
if (!inputType || !resultType || !inputType.hasDtype() ||
!resultType.hasDtype())
return failure();
auto inputDtype = inputType.getDtype();
auto resultDtype = resultType.getDtype();

SmallVector<OpFoldResult> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();

ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<OpFoldResult> converted;
if (failed(convertOpFoldResults(b, converted, elements, inputDtype,
resultDtype)))
return rewriter.notifyMatchFailure(
op, "Unhandled attribute type encountered.");

SmallVector<Value> vals;
if (failed(materializeFolds(b, converted, vals)))
return failure();

Value result = constructAtenTensorOpFromList(b, op.getType(), vals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace

namespace {
template <typename AtenViewLikeOp>
class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
@@ -828,7 +931,7 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
if (failed(materializeFolds(b, resultFolds, resultVals)))
return failure();

if (resultTy.getSizes().empty()) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, resultTy, resultVals.front());
return success();
@@ -841,6 +944,48 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
};
} // namespace

namespace {
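// Scalarizes an elementwise unary op on a rank-0 or rank-1 integer tensor by
// applying the corresponding scalar op (e.g. AtenNegIntOp for AtenNegOp) to
// each extracted element and rebuilding the result tensor from the scalar
// values.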
template <typename OpTy, typename ScalarOpTy>
class PropagateAtenUnaryPattern : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Check type
auto resultTy = cast<ValueTensorType>(op.getType());
if (resultTy.getSizes().size() > 1)
return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
return rewriter.notifyMatchFailure(op, "not an int type");

ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<OpFoldResult> selfFold;
if (failed(getListFromTensor(op.getSelf(), selfFold)))
return failure();
SmallVector<Value> selfVals;
if (failed(materializeFolds(b, selfFold, selfVals)))
return failure();
SmallVector<OpFoldResult> resultFolds;
for (uint64_t i = 0; i < selfVals.size(); i++) {
resultFolds.push_back(
b.createOrFold<ScalarOpTy>(selfVals[i].getType(), selfVals[i]));
}
SmallVector<Value> resultVals;
if (failed(materializeFolds(b, resultFolds, resultVals)))
return failure();

if (resultTy.getSizes().size() == 0) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, resultTy, resultVals.front());
return success();
}

Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
/// ------ Fold Patterns ------ ///
// These are shape-specific folding patterns

@@ -915,19 +1060,22 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
auto resultTy = cast<BaseTensorType>(op.getType());
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "dynamic output shape");
if (resultTy.getSizes().size() == 0) {
rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
op, op.getType(), elements.front());
return success();
}

auto loc = op.getLoc();
SmallVector<Value> sizes;
for (auto size : resultTy.getSizes())
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(size)));

Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
sizes);

Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
@@ -1031,6 +1179,24 @@ class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
};
} // namespace

namespace {
// Fold redundant cast chains like size.int -> float.scalar -> int.scalar back
// to the original size.int value.
class FoldAtenIntScalarPattern : public OpRewritePattern<AtenIntScalarOp> {
public:
using OpRewritePattern<AtenIntScalarOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIntScalarOp op,
PatternRewriter &rewriter) const override {
auto floatScalarOp = op.getA().getDefiningOp<AtenFloatScalarOp>();
if (!floatScalarOp)
return failure();
auto sizeOp = floatScalarOp.getA().getDefiningOp<AtenSizeIntOp>();
if (!sizeOp)
return failure();
rewriter.replaceOp(op, floatScalarOp.getA());
return success();
}
};
} // namespace
namespace {
class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
public:
@@ -1182,8 +1348,29 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
if (inputUnmatched == 1 && outputUnmatched > 1) {
Value dimVal =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), leftMatchEnd);
SmallVector<Value> unflattenSizes(viewSizes.begin() + leftMatchEnd,
viewSizes.end() - rightMatchEnd);
// try to convert a single dynamic size input to -1
int64_t dynCount = 0;
int64_t dynIdx = 0;
for (auto [i, v] : llvm::enumerate(unflattenSizes)) {
int64_t szeInt;
if (!matchPattern(v, m_TorchConstantInt(&szeInt))) {
dynCount++;
dynIdx = i;
continue;
}
// if we have a -1 already, make dynCount invalid and break
if (szeInt == -1) {
dynCount = -1;
break;
}
}
// if only one size is dynamic, make it -1
if (dynCount == 1)
unflattenSizes[dynIdx] =
rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);

Value unflattenList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), op.getSize().getType(), unflattenSizes);
rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
@@ -1227,6 +1414,18 @@ template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {

namespace {

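// Returns true if `op` is an aten.item op whose result feeds an
// aten.slice.Tensor op; isAnchorOp below treats such item ops as anchor
// points for shape scalarization.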
bool isItemForSliceOp(Operation *op) {
auto itemOp = dyn_cast_or_null<AtenItemOp>(op);
if (!itemOp)
return false;
for (OpOperand &use : op->getUses()) {
Operation *userOp = use.getOwner();
if (isa<AtenSliceTensorOp>(userOp))
return true;
}
return false;
}

bool isSourceOpForShapeScalarization(Operation *op) {
return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
Expand All @@ -1244,7 +1443,7 @@ bool isPrimListOfInts(Operation *op) {

bool isAnchorOp(Operation *op) {
return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
isPrimListOfInts(op) || isItemForSliceOp(op);
}

// The argument to this function, op, is the use of some source op, srcOp. If
@@ -1278,9 +1477,9 @@ bool isInvalidValidViewConsumer(Operation *op,
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
FoldAtenSqueezePattern<AtenSqueezeDimOp>,
FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern,
FoldAtenWhereSelf, FoldAtenTensorSplatPattern,
FoldAtenEqIntPattern>(patterns.getContext());
}

void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
Expand All @@ -1303,24 +1502,29 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern,
PropagateAtenUnaryPattern<AtenNegOp, AtenNegIntOp>,
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
PropagateAtenArithmeticPattern<AtenRemainderTensorOp, AtenRemainderIntOp>,
PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
patterns.getContext());
}

void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
RemoveUnusedPattern<Torch::AtenEqIntOp>,
RemoveUnusedPattern<Torch::AtenToDtypeOp>,
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
RemoveUnusedPattern<Torch::AtenFullOp>,
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
RemoveUnusedPattern<Torch::AtenIntScalarOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
patterns.getContext());
}