From f20f7bd1037fc121022ea0155cc50e032287d921 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Thu, 7 Nov 2024 16:51:31 -0600 Subject: [PATCH 01/10] Add propagation for aten.to.dtype --- .../Torch/Transforms/ScalarizeShapes.cpp | 119 +++++++++++++++++- 1 file changed, 115 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 9a85fbaa8646..5f81958b9604 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -746,6 +746,98 @@ class PropagateAtenItemPattern : public OpRewritePattern { }; } // namespace +namespace { + +LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b, + SmallVector &converted, + SmallVector &elements, + Type inputDtype, Type resultDtype) { + auto inputIsInt = dyn_cast(inputDtype); + auto resultIsInt = dyn_cast(resultDtype); + if (!inputIsInt && !isa(inputDtype)) + return failure(); + if (!resultIsInt && !isa(resultDtype)) + return failure(); + // if dtypes are both int or both float, no conversion needed + if (static_cast(inputIsInt) == static_cast(resultIsInt)) { + converted = elements; + return success(); + } + for (auto e : elements) { + auto eValue = dyn_cast(e); + if (eValue && resultIsInt) { + converted.push_back(b.createOrFold(eValue)); + continue; + } + if (eValue && !resultIsInt) { + converted.push_back(b.createOrFold(eValue)); + continue; + } + auto eAttr = dyn_cast(e); + if (auto eIntAttr = dyn_cast_or_null(eAttr)) { + auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue() + : eIntAttr.getValue().getZExtValue(); + converted.push_back(FloatAttr::get(cast(resultDtype), + static_cast(eInt))); + continue; + } + if (auto eFloatAttr = dyn_cast_or_null(eAttr)) { + converted.push_back(IntegerAttr::get( + resultDtype, static_cast(eFloatAttr.getValueAsDouble()))); + continue; + } + return failure(); + } + return success(); +} + +class PropagateAtenToDtypePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenToDtypeOp op, + PatternRewriter &rewriter) const override { + bool nonBlocking, copyArg; + // The non_blocking arg must be `False`. + if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) || + nonBlocking) + return failure(); + // The copy arg must be `False`. + if (!matchPattern(op.getCopy(), m_TorchConstantBool(©Arg)) || copyArg) + return failure(); + // The memory_format arg must be `none`. + if (!isa(op.getMemoryFormat().getType())) + return failure(); + + auto inputType = dyn_cast(op.getSelf().getType()); + auto resultType = dyn_cast(op.getType()); + if (!inputType || !resultType || !inputType.hasDtype() || + !resultType.hasDtype()) + return failure(); + auto inputDtype = inputType.getDtype(); + auto resultDtype = resultType.getDtype(); + + SmallVector elements; + if (failed(getListFromTensor(op.getSelf(), elements))) + return failure(); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector converted; + if (failed(convertOpFoldResults(b, converted, elements, inputDtype, + resultDtype))) + return rewriter.notifyMatchFailure( + op, "Unhandled attribute type encountered."); + + SmallVector vals; + if (failed(materializeFolds(b, converted, vals))) + return failure(); + + Value result = constructAtenTensorOpFromList(b, op.getType(), vals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + namespace { template class PropagateAtenViewLikePattern : public OpRewritePattern { @@ -1031,6 +1123,24 @@ class FoldAtenWhereSelf : public OpRewritePattern { }; } // namespace +namespace { +// fold ridiculous patterns like size.int -> float.scalar -> int.scalar +class FoldAtenIntScalarPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIntScalarOp op, + PatternRewriter &rewriter) const override { + auto floatScalarOp = op.getA().getDefiningOp(); + if (!floatScalarOp) + return failure(); + auto sizeOp = floatScalarOp.getA().getDefiningOp(); + if (!sizeOp) + return failure(); + rewriter.replaceOp(op, floatScalarOp.getA()); + return success(); + } +}; +} // namespace namespace { class FoldAtenUnsqueezePattern : public OpRewritePattern { public: @@ -1263,9 +1373,9 @@ bool isInvalidValidViewConsumer(Operation *op, void populateScalarizationFoldPatterns(RewritePatternSet &patterns) { patterns.insert, FoldAtenSqueezePattern, - FoldAtenUnsqueezePattern, FoldAtenWhereSelf, - FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>( - patterns.getContext()); + FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern, + FoldAtenWhereSelf, FoldAtenTensorSplatPattern, + FoldAtenEqIntPattern>(patterns.getContext()); } void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) { @@ -1288,7 +1398,7 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern, PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern, - PropagateAtenTransposeIntPattern, + PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, @@ -1299,6 +1409,7 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { patterns.insert, RemoveUnusedPattern, + RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, From 78734305154406d8202643b6ecb91c6e7bba44a1 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Fri, 8 Nov 2024 19:17:15 -0600 Subject: [PATCH 02/10] Add scalarization for NegTensor and RemainderTensor Add canonicalization to remove complex dynamic shape calculations from unflatten ops whenever these could just be -1 Add lowering to arith for neg.int and remainder.int ops --- lib/Conversion/TorchToArith/TorchToArith.cpp | 23 +++++++++- lib/Dialect/Torch/IR/TorchOps.cpp | 31 +++++++++++++ .../Torch/Transforms/ScalarizeShapes.cpp | 44 +++++++++++++++++++ 3 files changed, 96 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 143b46694030..8e28c2f2ca0f 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -82,6 +82,22 @@ class ConvertAtenBinaryOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenNegIntOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenNegIntOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = adaptor.getA(); + rewriter.replaceOpWithNewOp( + op, a, rewriter.create(op.getLoc(), -1, 64)); + return success(); + } +}; +} // namespace + namespace { template class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern { @@ -465,11 +481,14 @@ class ConvertTorchToArith target.addIllegalOp(); patterns.add(typeConverter, context); - + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); + AtenMulIntOp, AtenRemainderIntOp>(); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 3774e65f0859..fee2d9c086c5 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2274,6 +2274,33 @@ OpFoldResult AtenUnflattenIntOp::fold(FoldAdaptor adaptor) { void AtenUnflattenIntOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { + // if all of the sizes are constant except one, make it -1 to improve shape + // inference. + patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) { + SmallVector sizeValues; + if (!getListConstructElements(op.getSizes(), sizeValues)) + return rewriter.notifyMatchFailure(op, + "sizes must come from list construct"); + int64_t nonConstantDims = 0; + int64_t nonConstantPos = -1; + for (auto [i, val] : llvm::enumerate(sizeValues)) { + int64_t dimSize; + bool isConstant = matchPattern(val, m_TorchConstantInt(&dimSize)) && + (dimSize != Torch::kUnknownSize); + nonConstantDims += static_cast(!isConstant); + if (!isConstant) + nonConstantPos = i; + } + if (nonConstantDims != 1) + return failure(); + sizeValues[nonConstantPos] = + rewriter.create(op.getLoc(), -1); + Value list = rewriter.create( + op.getLoc(), op.getSizes().getType(), sizeValues); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), op.getDim(), list); + return success(); + }); // if there are only two sizes and one of them is statically 1, then convert // to an unqueeze. patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) { @@ -4068,6 +4095,10 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); + if ((lConstant && lhs == 1)) + return getOperand(1); + if ((rConstant && rhs == 1)) + return getOperand(0); if ((lConstant && lhs == 0) || (rConstant && rhs == 0)) return getI64IntegerAttr(getContext(), 0); if (lConstant && rConstant) diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 7e35b14eeef2..6d0e0831ae8f 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -933,6 +933,48 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern { }; } // namespace +namespace { +template +class PropagateAtenUnaryPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Check type + auto resultTy = cast(op.getType()); + if (resultTy.getSizes().size() > 1) + return rewriter.notifyMatchFailure(op, "unsupported: rank > 1"); + if (!resultTy.hasDtype() || !isa(resultTy.getDtype())) + return rewriter.notifyMatchFailure(op, "not an int type"); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector selfFold; + if (failed(getListFromTensor(op.getSelf(), selfFold))) + return failure(); + SmallVector selfVals; + if (failed(materializeFolds(b, selfFold, selfVals))) + return failure(); + SmallVector resultFolds; + for (uint64_t i = 0; i < selfVals.size(); i++) { + resultFolds.push_back( + b.createOrFold(selfVals[i].getType(), selfVals[i])); + } + SmallVector resultVals; + if (failed(materializeFolds(b, resultFolds, resultVals))) + return failure(); + + if (resultTy.getSizes().size() == 0) { + rewriter.replaceOpWithNewOp( + op, resultTy, resultVals.front()); + return success(); + } + + Value result = constructAtenTensorOpFromList(b, resultTy, resultVals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace /// ------ Fold Patterns ------ /// // These are shape-specific folding patterns @@ -1414,9 +1456,11 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern, PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern, + PropagateAtenUnaryPattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, + PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern>( patterns.getContext()); } From 9c7e3aadfd51c2bffd7f9f55a48feb66ee8398e7 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Fri, 8 Nov 2024 20:12:35 -0600 Subject: [PATCH 03/10] fix failing lit tests --- test/Dialect/Torch/decompose-complex-ops.mlir | 6 +++--- test/Dialect/Torch/scalarize-shapes.mlir | 12 +++++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index f938a2637835..4da482af03f3 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -105,9 +105,9 @@ func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.v // CHECK-LABEL: test_einsum_inner_prod func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} { - // CHECK: %[[INT5:.+]] = torch.constant.int 5 - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[INT5:.+]] = torch.constant.int 5 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]] // CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]] // CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]] diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index 5ea715735c70..edf36c2cae2d 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -89,14 +89,12 @@ func.func @arith_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?] // CHECK: %[[x2:.*]] = torch.aten.floordiv.int %[[x0]], %[[int12]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[x3:.*]] = torch.aten.floordiv.int %[[x1]], %[[int1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[int12_1:.*]] = torch.constant.int 12 - // CHECK: %[[int1_2:.*]] = torch.constant.int 1 // CHECK: %[[x4:.*]] = torch.aten.mul.int %[[x2]], %[[int12_1]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x5:.*]] = torch.aten.mul.int %[[x3]], %[[int1_2]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x7:.*]] = torch.aten.sub.int %[[x1]], %[[x5]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x8:.*]] = torch.prim.ListConstruct %[[x7]], %[[x6]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[x9:.*]] = torch.aten.constant_pad_nd %arg0, %[[x8]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?],f32> - // CHECK: return %[[x9]] : !torch.vtensor<[?,?],f32> + // CHECK: %[[x5:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x1]], %[[x3]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x7:.*]] = torch.prim.ListConstruct %[[x6]], %[[x5]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[x8:.*]] = torch.aten.constant_pad_nd %arg0, %[[x7]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?],f32> + // CHECK: return %[[x8]] : !torch.vtensor<[?,?],f32> %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %1 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> %float0.000000e00 = torch.constant.float 0.000000e+00 From db19a25e568c4449928c380b6b4133665578a4a5 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Fri, 8 Nov 2024 20:19:15 -0600 Subject: [PATCH 04/10] don't do something dumb --- lib/Dialect/Torch/IR/TorchOps.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index fee2d9c086c5..bfc089dce929 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2285,8 +2285,10 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns( int64_t nonConstantPos = -1; for (auto [i, val] : llvm::enumerate(sizeValues)) { int64_t dimSize; - bool isConstant = matchPattern(val, m_TorchConstantInt(&dimSize)) && - (dimSize != Torch::kUnknownSize); + bool isConstant = matchPattern(val, m_TorchConstantInt(&dimSize)); + if (isConstant && dimSize == -1) { + return failure(); + } nonConstantDims += static_cast(!isConstant); if (!isConstant) nonConstantPos = i; From 441beaaf5da0ccb0fafd3345d32bc4ce7b87189b Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 11 Nov 2024 11:25:27 -0600 Subject: [PATCH 05/10] add better handling of rank 0 tensors; fix a bug in unflatten decomp --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 5 +++++ lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp | 9 ++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index aa15e3735dae..8410191bd5d7 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4587,6 +4587,11 @@ class DecomposeAtenUnflattenIntOp if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + if (inputShape[dimInt] == Torch::kUnknownSize && + llvm::count(sizesInts, -1) > 0) + return rewriter.notifyMatchFailure( + op, "Unimplmented: dynamic unflatten dim with an inferred size."); + SmallVector sizesTorchInt; if (!getListConstructElements(op.getSizes(), sizesTorchInt)) return rewriter.notifyMatchFailure( diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 6d0e0831ae8f..9a452f1909c4 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -1049,6 +1049,11 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern { auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown()) return rewriter.notifyMatchFailure(op, "dynamic output shape"); + if (resultTy.getSizes().size() == 0) { + rewriter.replaceOpWithNewOp( + op, op.getType(), elements.front()); + return success(); + } auto loc = op.getLoc(); SmallVector sizes; @@ -1056,12 +1061,10 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern { sizes.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(size))); - Value one = rewriter.create( - loc, rewriter.getType(), 1); Value sizeList = rewriter.create( loc, rewriter.getType(rewriter.getType()), - one); + sizes); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); From 0c080d20797c033c1f5a9b0d0bb21f6a154d63b1 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 11 Nov 2024 12:21:57 -0600 Subject: [PATCH 06/10] add item ops to worklist if they get consumed by slice ops --- lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 9a452f1909c4..81d6893139cb 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -1382,6 +1382,18 @@ template class RemoveUnusedPattern : public OpRewritePattern { namespace { +bool isItemForSliceOp(Operation *op) { + auto itemOp = dyn_cast_or_null(op); + if (!itemOp) + return false; + for (OpOperand &use : op->getUses()) { + Operation *userOp = use.getOwner(); + if (isa(userOp)) + return true; + } + return false; +} + bool isSourceOpForShapeScalarization(Operation *op) { return llvm::isa(op); @@ -1399,7 +1411,7 @@ bool isPrimListOfInts(Operation *op) { bool isAnchorOp(Operation *op) { return isa(op) || isa(op) || - isPrimListOfInts(op); + isPrimListOfInts(op) || isItemForSliceOp(op); } // The argument to this function, op, is the use of some source op, srcOp. If From 5bea9eb6321c3b947e17a11b96752a789b2f956c Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 11 Nov 2024 12:33:03 -0600 Subject: [PATCH 07/10] move -1 unflatten inference to scalarize shapes --- lib/Dialect/Torch/IR/TorchOps.cpp | 29 ------------------- .../Torch/Transforms/ScalarizeShapes.cpp | 25 ++++++++++++++-- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index bfc089dce929..ca81f780c3a6 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2274,35 +2274,6 @@ OpFoldResult AtenUnflattenIntOp::fold(FoldAdaptor adaptor) { void AtenUnflattenIntOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { - // if all of the sizes are constant except one, make it -1 to improve shape - // inference. - patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) { - SmallVector sizeValues; - if (!getListConstructElements(op.getSizes(), sizeValues)) - return rewriter.notifyMatchFailure(op, - "sizes must come from list construct"); - int64_t nonConstantDims = 0; - int64_t nonConstantPos = -1; - for (auto [i, val] : llvm::enumerate(sizeValues)) { - int64_t dimSize; - bool isConstant = matchPattern(val, m_TorchConstantInt(&dimSize)); - if (isConstant && dimSize == -1) { - return failure(); - } - nonConstantDims += static_cast(!isConstant); - if (!isConstant) - nonConstantPos = i; - } - if (nonConstantDims != 1) - return failure(); - sizeValues[nonConstantPos] = - rewriter.create(op.getLoc(), -1); - Value list = rewriter.create( - op.getLoc(), op.getSizes().getType(), sizeValues); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getSelf(), op.getDim(), list); - return success(); - }); // if there are only two sizes and one of them is statically 1, then convert // to an unqueeze. patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) { diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 81d6893139cb..afb2967347ab 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -1337,8 +1337,29 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern { if (inputUnmatched == 1 && outputUnmatched > 1) { Value dimVal = rewriter.create(op.getLoc(), leftMatchEnd); - ArrayRef unflattenSizes(viewSizes.begin() + leftMatchEnd, - viewSizes.end() - rightMatchEnd); + SmallVector unflattenSizes(viewSizes.begin() + leftMatchEnd, + viewSizes.end() - rightMatchEnd); + // try to convert a single dynamic size input to -1 + int64_t dynCount = 0; + int64_t dynIdx = 0; + for (auto [i, v] : llvm::enumerate(unflattenSizes)) { + int64_t szeInt; + if (!matchPattern(v, m_TorchConstantInt(&szeInt))) { + dynCount++; + dynIdx = i; + continue; + } + // if we a -1 already, make dynCount invalid and break + if (szeInt == -1) { + dynCount = -1; + break; + } + } + // if only one size is dynamic, make it -1 + if (dynCount == 1) + unflattenSizes[dynIdx] = + rewriter.create(op.getLoc(), -1); + Value unflattenList = rewriter.create( op.getLoc(), op.getSize().getType(), unflattenSizes); rewriter.replaceOpWithNewOp( From 5a5adcd5d4b57f4380f24be436b2a2e5c0d4137f Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 11 Nov 2024 13:08:20 -0600 Subject: [PATCH 08/10] fix lit test --- test/Dialect/Torch/scalarize-shapes.mlir | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index edf36c2cae2d..60c23274f7c9 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -27,12 +27,8 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso // CHECK-LABEL: @shape_as_tensor_dim func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]] - // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false - // CHECK-DAG: %[[NONE:.+]] = torch.constant.none - // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1_0]] - // CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]] + // CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]] + // CHECK: %[[TENSOR:.+]] = torch.prim.NumToTensor.Scalar %[[SZ]] : !torch.int -> !torch.vtensor<[],si32> // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32> %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> %dim = torch.constant.int 0 From 28299c438d31994a23ddac95de393df2f61b6506 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 11 Nov 2024 13:31:21 -0600 Subject: [PATCH 09/10] add some tests and remove unused scalar cast ops --- .../Torch/Transforms/ScalarizeShapes.cpp | 2 + test/Dialect/Torch/scalarize-shapes.mlir | 43 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index afb2967347ab..efdc5aea1eac 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -1512,6 +1512,8 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, RemoveUnusedPattern>( patterns.getContext()); } diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index 60c23274f7c9..c7fc2c280a2b 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -39,6 +39,49 @@ func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vt return %select : !torch.vtensor<[],si32> } +// ----- + +// CHECK-LABEL: @cast_int_int +func.func @cast_int_int(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si64> { + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SZE]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si64> + %int4 = torch.constant.int 4 + %false = torch.constant.bool false + %none = torch.constant.none + %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> + %cast_shape = torch.aten.to.dtype %shape, %int4, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],si64> + %dim = torch.constant.int 0 + %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> + %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si64> + %item = torch.aten.item %select : !torch.vtensor<[],si64> -> !torch.int + %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list + return %select : !torch.vtensor<[],si64> +} + +// ----- + +// CHECK-LABEL: @cast_int_float +func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],f32> { + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar %[[SZE]] : !torch.int -> !torch.float + // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[FLOAT]] : !torch.float -> !torch.vtensor<[],f32> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[],f32> + %int6 = torch.constant.int 6 + %false = torch.constant.bool false + %none = torch.constant.none + %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> + %cast_shape = torch.aten.to.dtype %shape, %int6, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],f32> + %dim = torch.constant.int 0 + %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> + %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],f32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],f32> + %item = torch.aten.item %select : !torch.vtensor<[],f32> -> !torch.float + %item_int = torch.aten.Int.Scalar %item : !torch.float -> !torch.int + %list = torch.prim.ListConstruct %item_int : (!torch.int) -> !torch.list + return %select : !torch.vtensor<[],f32> +} // ----- From 0088a40d1d4a7ac1cb7ed2c10e1805fb751c7bad Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Tue, 12 Nov 2024 12:57:41 -0600 Subject: [PATCH 10/10] address comments --- lib/Conversion/TorchToArith/TorchToArith.cpp | 7 ++- lib/Dialect/Torch/IR/TorchOps.cpp | 4 +- .../Torch/Transforms/DecomposeComplexOps.cpp | 2 +- .../Torch/Transforms/ScalarizeShapes.cpp | 55 +++++++++++-------- 4 files changed, 41 insertions(+), 27 deletions(-) diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 8e28c2f2ca0f..458ea31852ec 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -91,8 +91,11 @@ class ConvertAtenNegIntOp : public OpConversionPattern { typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value a = adaptor.getA(); - rewriter.replaceOpWithNewOp( - op, a, rewriter.create(op.getLoc(), -1, 64)); + rewriter.replaceOpWithNewOp( + op, + rewriter.create(op.getLoc(), /*value=*/0, + /*bitwidth=*/64), + a); return success(); } }; diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index ca81f780c3a6..868c5ef67a46 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4068,9 +4068,9 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); - if ((lConstant && lhs == 1)) + if (lConstant && lhs == 1) return getOperand(1); - if ((rConstant && rhs == 1)) + if (rConstant && rhs == 1) return getOperand(0); if ((lConstant && lhs == 0) || (rConstant && rhs == 0)) return getI64IntegerAttr(getContext(), 0); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 8410191bd5d7..9db8a6949063 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4590,7 +4590,7 @@ class DecomposeAtenUnflattenIntOp if (inputShape[dimInt] == Torch::kUnknownSize && llvm::count(sizesInts, -1) > 0) return rewriter.notifyMatchFailure( - op, "Unimplmented: dynamic unflatten dim with an inferred size."); + op, "Unimplemented: dynamic unflatten dim with an inferred size."); SmallVector sizesTorchInt; if (!getListConstructElements(op.getSizes(), sizesTorchInt)) diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index efdc5aea1eac..989057501957 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern { ImplicitLocOpBuilder b(op.getLoc(), rewriter); // Rank 0 item op prop - if (selfTy.getSizes().size() == 0) { + if (selfTy.getSizes().empty()) { auto numToTensor = self.getDefiningOp(); auto squeezeDim = self.getDefiningOp(); if (!squeezeDim && !numToTensor) @@ -758,35 +758,46 @@ LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b, return failure(); if (!resultIsInt && !isa(resultDtype)) return failure(); + // if dtypes are both int or both float, no conversion needed if (static_cast(inputIsInt) == static_cast(resultIsInt)) { converted = elements; return success(); } - for (auto e : elements) { - auto eValue = dyn_cast(e); - if (eValue && resultIsInt) { - converted.push_back(b.createOrFold(eValue)); - continue; + + if (resultIsInt) { + for (auto &e : elements) { + auto eValue = dyn_cast(e); + if (eValue) { + converted.push_back(b.createOrFold(eValue)); + continue; + } + auto eAttr = dyn_cast(e); + auto eFloatAttr = dyn_cast_or_null(eAttr); + if (!eFloatAttr) + return failure(); + + converted.push_back(IntegerAttr::get( + resultDtype, static_cast(eFloatAttr.getValueAsDouble()))); } - if (eValue && !resultIsInt) { + return success(); + } + + // result is float + for (auto &e : elements) { + auto eValue = dyn_cast(e); + if (eValue) { converted.push_back(b.createOrFold(eValue)); continue; } auto eAttr = dyn_cast(e); - if (auto eIntAttr = dyn_cast_or_null(eAttr)) { - auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue() - : eIntAttr.getValue().getZExtValue(); - converted.push_back(FloatAttr::get(cast(resultDtype), - static_cast(eInt))); - continue; - } - if (auto eFloatAttr = dyn_cast_or_null(eAttr)) { - converted.push_back(IntegerAttr::get( - resultDtype, static_cast(eFloatAttr.getValueAsDouble()))); - continue; - } - return failure(); + auto eIntAttr = dyn_cast(eAttr); + if (!eIntAttr) + return failure(); + + auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue() + : eIntAttr.getValue().getZExtValue(); + converted.push_back(FloatAttr::get(resultDtype, static_cast(eInt))); } return success(); } @@ -920,7 +931,7 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern { if (failed(materializeFolds(b, resultFolds, resultVals))) return failure(); - if (resultTy.getSizes().size() == 0) { + if (resultTy.getSizes().empty()) { rewriter.replaceOpWithNewOp( op, resultTy, resultVals.front()); return success(); @@ -1349,7 +1360,7 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern { dynIdx = i; continue; } - // if we a -1 already, make dynCount invalid and break + // if we have a -1 already, make dynCount invalid and break if (szeInt == -1) { dynCount = -1; break;