[Torch] : Implement lowering of torch.frac op #3847

Open · wants to merge 1 commit into base: main
23 changes: 23 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5482,6 +5482,29 @@ def Torch_AtenPolarOp : Torch_Op<"aten.polar", [
}];
}

def Torch_AtenFracOp : Torch_Op<"aten.frac", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::frac : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenFracOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenFracOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}
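
For reference, `aten::frac` mirrors `torch.frac`, which returns the elementwise fractional part of its input and keeps the sign of the input. A minimal PyTorch sketch of the intended semantics (values chosen only for illustration):

import torch

x = torch.tensor([1.5, -2.25, 3.0])
# torch.frac is documented as the fractional portion of each element,
# i.e. x - trunc(x), so the result keeps the sign of the input.
print(torch.frac(x))       # fractional parts: 0.5, -0.25, 0.0
print(x - torch.trunc(x))  # same values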

def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
AllowsTypeRefinement,
HasValueSemantics,
8 changes: 8 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6703,6 +6703,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.frac\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.log\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -11853,6 +11857,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.frac\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" return %int6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %int11 = torch.constant.int 11\n"
29 changes: 29 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8012,6 +8012,34 @@ class DecomposeAtenTruncOp : public OpRewritePattern<AtenTruncOp> {
};
} // namespace

namespace {
// decompose `frac(x)` to `x - sign(x) * floor(abs(x))`
Review comment (Collaborator):
The decomposition written below follows a different implementation. Better to either remove this comment or change it as per the decomposition.

class DecomposeAtenFracOp : public OpRewritePattern<AtenFracOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenFracOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.getSelf();

    auto resultTy = dyn_cast<ValueTensorType>(op.getType());
    if (!resultTy || !resultTy.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result must have dtype");
    }

    if (isa<mlir::FloatType>(resultTy.getDtype())) {
      // For floating-point results, emit frac(x) = x - trunc(x) as
      // aten.trunc followed by aten.sub.Tensor with alpha = 1.
      Value trunc = rewriter.create<AtenTruncOp>(loc, resultTy, self);
      Value alpha =
          rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
      rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, trunc,
                                                   alpha);
      return success();
    }

    return failure();
  }
};
} // namespace
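
As a quick sanity check on the rewrite (not part of the patch), the emitted form x - 1.0 * trunc(x) can be compared against torch.frac directly in PyTorch; the values below are illustrative only:

import torch

x = torch.tensor([-2.3, -0.5, 0.0, 1.5, 7.75])
# aten.frac is rewritten as aten.sub.Tensor(self, aten.trunc(self), alpha=1),
# i.e. frac(x) = x - 1.0 * trunc(x).
decomposed = x - 1.0 * torch.trunc(x)
print(torch.allclose(torch.frac(x), decomposed))  # True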

namespace {
// decompose `fmod(x, y)` to `x - trunc(x/y) * y`
class DecomposeAtenFmodTensorOp : public OpRewritePattern<AtenFmodTensorOp> {
@@ -10262,6 +10290,7 @@ class DecomposeComplexOpsPass
    addPatternIfTargetOpIsIllegal<DecomposeAtenRad2degOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTruncOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenFracOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenFmodTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -541,6 +541,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenRad2degOp>();
  target.addIllegalOp<AtenCosineSimilarityOp>();
  target.addIllegalOp<AtenTruncOp>();
  target.addIllegalOp<AtenFracOp>();
  target.addIllegalOp<AtenNewEmptyStridedOp>();
  target.addIllegalOp<AtenEmptyStridedOp>();
  target.addIllegalOp<AtenBucketizeTensorOp>();
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -3245,6 +3245,8 @@
"Rot90MultipleRotationsModule_basic",
"Rot90NegativeEvenRotationsModule_basic",
"Rot90NegativeOddRotationsModule_basic",
# Error: 'aten::frac' to ONNX opset version 20 is not supported
"ElementwiseFracModule_basic",
# Failure - unknown
"BernoulliModule_basic",
"Conv_Transpose1dModule_basic",
@@ -313,6 +313,9 @@ def aten〇ceil〡shape(self: List[int]) -> List[int]:
def aten〇trunc〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇frac〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇log〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
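
For context, the upstream unary shape helper simply copies the input shape, so aten.frac is declared shape-preserving. A tiny stand-in sketch, assuming torch.jit._shape_functions.unary behaves this way:

from typing import List

def unary(self: List[int]) -> List[int]:
    # Elementwise ops keep the input shape; return a copy of it.
    return list(self)

assert unary([1, 6]) == [1, 6]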

@@ -2897,6 +2900,9 @@ def aten〇trunc〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

def aten〇frac〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    return torch.float32

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0))
def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int:
    self_rank, self_dtype = self_rank_dtype
@@ -507,6 +507,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::hardshrink : (Tensor, Scalar) -> (Tensor)")
emit("aten::softshrink : (Tensor, Scalar) -> (Tensor)")
emit("aten::polar : (Tensor, Tensor) -> (Tensor)")
emit("aten::frac : (Tensor) -> (Tensor)")

# Ops with dynamic number of outputs
emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -2839,6 +2839,29 @@ def ElementwiseTruncIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseFracModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([1, 6], torch.float32, True),
        ]
    )
    def forward(self, a):
        return torch.frac(a)


@register_test_case(module_factory=lambda: ElementwiseFracModule())
def ElementwiseFracModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5]]))

Review comment (Collaborator) on lines +2842 to +2860: Does it take f64 input? If yes, can you please add a test for the same?
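
A follow-up test along these lines might look like the sketch below; the module name is hypothetical, and whether it passes end to end depends on how the f64 case interacts with the dtype function above (which currently pins the result to torch.float32, dtype code 6 in the generated interp library):

# Hypothetical f64 variant of the test above; illustrative only, not part of this patch.
class ElementwiseFracF64Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([1, 6], torch.float64, True),
        ]
    )
    def forward(self, a):
        return torch.frac(a)


@register_test_case(module_factory=lambda: ElementwiseFracF64Module())
def ElementwiseFracF64Module_basic(module, tu: TestUtils):
    module.forward(
        torch.tensor(
            [[-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5]],
            dtype=torch.float64,
        )
    )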


# ==============================================================================


class ElementwiseSignModule(torch.nn.Module):
    def __init__(self):
        super().__init__()