-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tosa] Add FP8 support #127730
base: main
Are you sure you want to change the base?
[mlir][tosa] Add FP8 support #127730
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Jerry-Ge (Jerry-Ge) Changes: Add FP8 support to the following TOSA operators: ARGMAX. Also added verifiers as needed to check input/output element types, and renamed the inputs of transpose_conv2d and select to match the spec. Patch is 73.53 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/127730.diff 13 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index d11ba65a13736..8947f7a9bd9a1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
}];
let arguments = (ins
- Tosa_Tensor: $input,
+ Tosa_Tensor_Extended: $input,
I32Attr: $axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
@@ -73,7 +73,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
+
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
@@ -83,7 +84,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
);
let results = (outs
- Tosa_Tensor4D:$output
+ Tosa_Tensor4D_Extended:$output
);
let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
@@ -102,7 +103,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
@@ -133,11 +134,12 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
}];
let arguments = (ins
- Tosa_Tensor5D:$input,
- TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
- Tosa_Tensor1D:$bias,
+ Tosa_Tensor5D_Extended:$input,
+ TensorRankOf<[Tosa_Weight], [5]>:$weight,
+ Tosa_Tensor1D_Extended:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
+
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
@@ -146,7 +148,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
);
let results = (outs
- Tosa_Tensor5D:$output
+ Tosa_Tensor5D_Extended:$output
);
let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
- Tosa_Tensor1D:$bias,
+ Tosa_Tensor1D_Extended:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
+
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
@@ -178,7 +181,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
);
let results = (outs
- Tosa_Tensor4D:$output
+ Tosa_Tensor4D_Extended:$output
);
let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -237,8 +240,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
}];
let arguments = (ins
- Tosa_Tensor3D:$a,
- Tosa_Tensor3D:$b,
+ Tosa_Tensor3D_Extended:$a,
+ Tosa_Tensor3D_Extended:$b,
OptionalAttr<I32Attr>:$a_zp,
OptionalAttr<I32Attr>:$b_zp
);
@@ -248,6 +251,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
);
let builders = [Tosa_MatMulOpQuantInfoBuilder];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -264,7 +268,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
@@ -273,10 +277,11 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
);
let results = (outs
- Tosa_Tensor4D:$output
+ Tosa_Tensor4D_Extended:$output
);
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -327,11 +332,12 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
- Tosa_Tensor1D:$bias,
+ Tosa_Tensor1D_Extended:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
+
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
@@ -1190,9 +1196,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
}];
let arguments = (ins
- Tosa_I1Tensor:$pred,
- Tosa_Tensor:$on_true,
- Tosa_Tensor:$on_false
+ Tosa_I1Tensor:$input1,
+ Tosa_Tensor:$input2,
+ Tosa_Tensor:$input3
);
let results = (outs
@@ -1200,9 +1206,10 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
);
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat = [{
- operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false)
+ operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
`)` `->` type($output)
}];
}
@@ -1518,16 +1525,17 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
}];
let arguments = (ins
- Variadic<Tosa_Tensor>:$input1,
+ Variadic<Tosa_Tensor_Extended>:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1563,14 +1571,14 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
}];
let arguments = (ins
- Tosa_RankedTensor:$input1,
+ Tosa_RankedTensor_Extended:$input1,
Tosa_Shape:$padding,
- Optional<Tosa_Rank0Tensor>:$pad_const,
+ Optional<Tosa_ScalarTensor_Extended>:$pad_const,
OptionalAttr<I32Attr>:$input_zp
);
let results = (outs
- Tosa_RankedTensor:$output
+ Tosa_RankedTensor_Extended:$output
);
let builders = [Tosa_PadOpQuantInfoBuilder,
@@ -1597,12 +1605,12 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
let hasVerifier = 1;
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
Tosa_Shape:$shape
);
let results = (outs
- Tosa_RankedTensor:$output
+ Tosa_RankedTensor_Extended:$output
);
let extraClassDeclaration = [{
@@ -1629,12 +1637,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let hasFolder = 1;
@@ -1656,13 +1664,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
Tosa_Shape:$start,
Tosa_Shape:$size
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let hasCanonicalizer = 1;
@@ -1681,11 +1689,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
Tosa_Shape:$multiples);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let extraClassDeclaration = [{
@@ -1709,12 +1717,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
}];
let arguments = (ins
- Tosa_Tensor:$input1,
- Tosa_Int32Tensor:$perms
+ Tosa_Tensor_Extended:$input1,
+ Tosa_Int32Or64Tensor:$perms
);
let results = (
- outs Tosa_Tensor:$output
+ outs Tosa_Tensor_Extended:$output
);
let extraClassDeclaration = [{
@@ -1743,13 +1751,14 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
}];
let arguments = (ins
- Tosa_Tensor3D:$values,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
+ Tosa_Tensor3D_Extended:$values,
+ 2DTensorOf<[Tosa_Int32]>:$indices
);
let results = (outs
- Tosa_Tensor3D:$output
+ Tosa_Tensor3D_Extended:$output
);
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1764,14 +1773,15 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
}];
let arguments = (ins
- Tosa_Tensor3D:$values_in,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
- Tosa_Tensor3D:$input
+ Tosa_Tensor3D_Extended:$values_in,
+ 2DTensorOf<[Tosa_Int32]>:$indices,
+ Tosa_Tensor3D_Extended:$input
);
let results = (outs
- Tosa_Tensor3D:$values_out
+ Tosa_Tensor3D_Extended:$values_out
);
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1828,37 +1838,66 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
| Mode | Input | Output |
|--------------------------|---------|---------|
- | signed 8 to bool | int8 | Boolean |
- | signed 16 to bool | int16 | Boolean |
- | signed 32 to bool | int32 | Boolean |
- | bool to 8 | Boolean | int8 |
- | bool to 16 | Boolean | int16 |
- | bool to 32 | Boolean | int32 |
- | signed 8 to signed 16 | int8 | int16 |
- | signed 8 to signed 32 | int8 | int32 |
- | signed 16 to signed 8 | int16 | int8 |
- | signed 16 to signed 32 | int16 | int32 |
- | signed 32 to signed 8 | int32 | int8 |
- | signed 32 to signed 16 | int32 | int16 |
- | float to signed 8 | float | int8 |
- | float to signed 16 | float | int16 |
- | signed 8 to float | int8 | float |
- | signed 16 to float | int16 | float |
- | float 32 to float 64 | float32 | float64 |
- | float 64 to float 32 | float64 | float32 |
- }];
-
- let arguments = (ins
- Tosa_Tensor:$input
- );
-
- let results = (outs
- Tosa_Tensor:$output
+ | bool to int 8 | Boolean | int8 |
+ | bool to int 16 | Boolean | int16 |
+ | bool to int 32 | Boolean | int32 |
+ | int 8 to bool | int8 | Boolean |
+ | int 8 to int 16 | int8 | int16 |
+ | int 8 to int 32 | int8 | int32 |
+ | int 8 to fp16 | int8 | float16 |
+ | int 8 to bf16 | int8 | bf16 |
+ | int 8 to fp32 | int8 | float32 |
+ | int 16 to bool | int16 | Boolean |
+ | int 16 to int 8 | int16 | int8 |
+ | int 16 to int 32 | int16 | int32 |
+ | int 16 to fp16 | int16 | float16 |
+ | int 16 to bf16 | int16 | bf16 |
+ | int 16 to fp32 | int16 | float32 |
+ | int 32 to bool | int32 | Boolean |
+ | int 32 to int 8 | int32 | int8 |
+ | int 32 to int 16 | int32 | int16 |
+ | int 32 to fp16 | int32 | float16 |
+ | int 32 to bf16 | int32 | bf16 |
+ | int 32 to fp32 | int32 | float32 |
+ | bf16 to int 8 | bf16 | int8 |
+ | bf16 to int 16 | bf16 | int16 |
+ | bf16 to int 32 | bf16 | int32 |
+ | bf16 to fp8e4m3 | bf16 | fp8e4m3 |
+ | bf16 to fp8e5m2 | bf16 | fp8e5m2 |
+ | bf16 to fp32 | bf16 | float32 |
+ | fp8e4m3 to fp16 | fp8e4m3 | float16 |
+ | fp8e4m3 to bf16 | fp8e4m3 | bf16 |
+ | fp8e4m3 to fp32 | fp8e4m3 | float32 |
+ | fp8e5m2 to fp16 | fp8e5m2 | float16 |
+ | fp8e5m2 to bf16 | fp8e5m2 | bf16 |
+ | fp8e5m2 to fp32 | fp8e5m2 | float32 |
+ | fp16 to int 8 | float16 | int8 |
+ | fp16 to int 16 | float16 | int16 |
+ | fp16 to int 32 | float16 | int32 |
+ | fp16 to fp8e4m3 | float16 | fp8e4m3 |
+ | fp16 to fp8e5m2 | float16 | fp8e5m2 |
+ | fp16 to fp32 | float16 | float32 |
+ | fp32 to int 8 | float32 | int8 |
+ | fp32 to int 16 | float32 | int16 |
+ | fp32 to int 32 | float32 | int32 |
+ | fp32 to fp8e4m3 | float32 | fp8e4m3 |
+ | fp32 to fp8e5m2 | float32 | fp8e5m2 |
+ | fp32 to bf16 | float32 | bf16 |
+ | fp32 to fp16 | float32 | float16 |
+ }];
+
+ let arguments = (ins
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$input
+ );
+
+ let results = (outs
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$output
);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1940,7 +1979,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);
let results = (outs
- TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64, Tosa_Int4]>]>:$output
);
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cf6ddc66f4ada..2c6e647ae32fd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -74,16 +74,25 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
Tosa_QuantizedType<"int16", [16, 0], 1>,
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_F8 : AnyTypeOf<[
+ F8E4M3FN,
+ F8E5M2]>;
+
//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;
+// Add F8 type support to Tosa_AnyNumber
+def Tosa_AnyNumber_Extended : AnyTypeOf<[Tosa_AnyNumber, Tosa_F8],
+ "number_extended">;
+
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
- Tosa_QuantizedInt, AnyFloat]>;
+ Tosa_QuantizedInt, AnyFloat, Tosa_F8]>;
+
//===----------------------------------------------------------------------===//
// TOSA Tensor Conformance
@@ -130,9 +139,11 @@ def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor_Extended : TosaTensorOf<[Tosa_AnyNumber_Extended]>;
// Must be ranked but no further constraints
-def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor_Extended : RankedTensorOf<[Tosa_AnyNumber_Extended]>;
// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
@@ -145,9 +156,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//
-def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
-
+// Scalar tensors: Rank-1 (with only one element)
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
+def Tosa_ScalarTensor_Extended : TosaScalarTensorOf<[Tosa_AnyNumber_Extended], [1]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
// We include unranked tensors as a supported type for all possible tosa
@@ -155,6 +166,7 @@ def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
// they should be shape propagate used Tosa's shape inference pass and verified
// to not include any remaining unranked tensors.
def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensorExtended : TosaUnrankedTensorOf<[Tosa_AnyNumber_Extended]>;
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
@@ -162,6 +174,17 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor1D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [1]>],
+ "1-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor2D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [2]>],
+ "2-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor3D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [3]>],
+ "3-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor4D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [4]>],
+ "4-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor5D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [5]>],
+ "5-d tosa-conformant tensor extended", "::mlir::TensorType">;
+
// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 69b3f6d674167..704f8a82d11fa 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
- auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
+ auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
return failure();
rewriter.modifyOpInPlace(op, [&]() {
op.getOperation()->setOperands(
- {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
+ {notOp.getInput1(), op.getInput3(), op.getInput2()});
});
return success();
}
@@ -1118,18 +1118,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
- if (getOnTrue() == getOnFalse())
- return getOnTrue();
+ if (getInput2() == getInput3())
+ return getInput2();
auto predicate =
- llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
+ llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
if (!predicate)
return {};
if (!predicate.isSplat())
return {};
- return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
- : getOnFalse();
+ return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
+ : getInput3();
}
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..411f06f4a0b7c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -217,15 +217,17 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
template <typename T>
static LogicalResult verifyConvOp(T op) {
- // All TOSA conv ops have an input and weight arguments which must be ranked
- // tensors.
+ // All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+
+ Ra...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: Jerry-Ge (Jerry-Ge) Changes: Add FP8 support to the following TOSA operators: ARGMAX. Also added verifiers as needed to check input/output element types, and renamed the inputs of transpose_conv2d and select to match the spec. Patch is 73.53 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/127730.diff 13 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index d11ba65a13736..8947f7a9bd9a1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
}];
let arguments = (ins
- Tosa_Tensor: $input,
+ Tosa_Tensor_Extended: $input,
I32Attr: $axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
@@ -73,7 +73,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
+
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
@@ -83,7 +84,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
);
let results = (outs
- Tosa_Tensor4D:$output
+ Tosa_Tensor4D_Extended:$output
);
let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
@@ -102,7 +103,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
@@ -133,11 +134,12 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
}];
let arguments = (ins
- Tosa_Tensor5D:$input,
- TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
- Tosa_Tensor1D:$bias,
+ Tosa_Tensor5D_Extended:$input,
+ TensorRankOf<[Tosa_Weight], [5]>:$weight,
+ Tosa_Tensor1D_Extended:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
+
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
@@ -146,7 +148,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
);
let results = (outs
- Tosa_Tensor5D:$output
+ Tosa_Tensor5D_Extended:$output
);
let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
- Tosa_Tensor1D:$bias,
+ Tosa_Tensor1D_Extended:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
+
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
@@ -178,7 +181,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
);
let results = (outs
- Tosa_Tensor4D:$output
+ Tosa_Tensor4D_Extended:$output
);
let builders = [Tosa_ConvOpQuantInfoBuilder];
@@ -237,8 +240,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
}];
let arguments = (ins
- Tosa_Tensor3D:$a,
- Tosa_Tensor3D:$b,
+ Tosa_Tensor3D_Extended:$a,
+ Tosa_Tensor3D_Extended:$b,
OptionalAttr<I32Attr>:$a_zp,
OptionalAttr<I32Attr>:$b_zp
);
@@ -248,6 +251,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
);
let builders = [Tosa_MatMulOpQuantInfoBuilder];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -264,7 +268,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
@@ -273,10 +277,11 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
);
let results = (outs
- Tosa_Tensor4D:$output
+ Tosa_Tensor4D_Extended:$output
);
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -327,11 +332,12 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
}];
let arguments = (ins
- Tosa_Tensor4D:$input,
+ Tosa_Tensor4D_Extended:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
- Tosa_Tensor1D:$bias,
+ Tosa_Tensor1D_Extended:$bias,
Optional<Tosa_ScalarTensor>:$input_zp,
Optional<Tosa_ScalarTensor>:$weight_zp,
+
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
@@ -1190,9 +1196,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
}];
let arguments = (ins
- Tosa_I1Tensor:$pred,
- Tosa_Tensor:$on_true,
- Tosa_Tensor:$on_false
+ Tosa_I1Tensor:$input1,
+ Tosa_Tensor:$input2,
+ Tosa_Tensor:$input3
);
let results = (outs
@@ -1200,9 +1206,10 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
);
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat = [{
- operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false)
+ operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
`)` `->` type($output)
}];
}
@@ -1518,16 +1525,17 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
}];
let arguments = (ins
- Variadic<Tosa_Tensor>:$input1,
+ Variadic<Tosa_Tensor_Extended>:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1563,14 +1571,14 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
}];
let arguments = (ins
- Tosa_RankedTensor:$input1,
+ Tosa_RankedTensor_Extended:$input1,
Tosa_Shape:$padding,
- Optional<Tosa_Rank0Tensor>:$pad_const,
+ Optional<Tosa_ScalarTensor_Extended>:$pad_const,
OptionalAttr<I32Attr>:$input_zp
);
let results = (outs
- Tosa_RankedTensor:$output
+ Tosa_RankedTensor_Extended:$output
);
let builders = [Tosa_PadOpQuantInfoBuilder,
@@ -1597,12 +1605,12 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
let hasVerifier = 1;
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
Tosa_Shape:$shape
);
let results = (outs
- Tosa_RankedTensor:$output
+ Tosa_RankedTensor_Extended:$output
);
let extraClassDeclaration = [{
@@ -1629,12 +1637,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let hasFolder = 1;
@@ -1656,13 +1664,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
Tosa_Shape:$start,
Tosa_Shape:$size
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let hasCanonicalizer = 1;
@@ -1681,11 +1689,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_Tensor_Extended:$input1,
Tosa_Shape:$multiples);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_Tensor_Extended:$output
);
let extraClassDeclaration = [{
@@ -1709,12 +1717,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
}];
let arguments = (ins
- Tosa_Tensor:$input1,
- Tosa_Int32Tensor:$perms
+ Tosa_Tensor_Extended:$input1,
+ Tosa_Int32Or64Tensor:$perms
);
let results = (
- outs Tosa_Tensor:$output
+ outs Tosa_Tensor_Extended:$output
);
let extraClassDeclaration = [{
@@ -1743,13 +1751,14 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
}];
let arguments = (ins
- Tosa_Tensor3D:$values,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
+ Tosa_Tensor3D_Extended:$values,
+ 2DTensorOf<[Tosa_Int32]>:$indices
);
let results = (outs
- Tosa_Tensor3D:$output
+ Tosa_Tensor3D_Extended:$output
);
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1764,14 +1773,15 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
}];
let arguments = (ins
- Tosa_Tensor3D:$values_in,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
- Tosa_Tensor3D:$input
+ Tosa_Tensor3D_Extended:$values_in,
+ 2DTensorOf<[Tosa_Int32]>:$indices,
+ Tosa_Tensor3D_Extended:$input
);
let results = (outs
- Tosa_Tensor3D:$values_out
+ Tosa_Tensor3D_Extended:$values_out
);
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1828,37 +1838,66 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
| Mode | Input | Output |
|--------------------------|---------|---------|
- | signed 8 to bool | int8 | Boolean |
- | signed 16 to bool | int16 | Boolean |
- | signed 32 to bool | int32 | Boolean |
- | bool to 8 | Boolean | int8 |
- | bool to 16 | Boolean | int16 |
- | bool to 32 | Boolean | int32 |
- | signed 8 to signed 16 | int8 | int16 |
- | signed 8 to signed 32 | int8 | int32 |
- | signed 16 to signed 8 | int16 | int8 |
- | signed 16 to signed 32 | int16 | int32 |
- | signed 32 to signed 8 | int32 | int8 |
- | signed 32 to signed 16 | int32 | int16 |
- | float to signed 8 | float | int8 |
- | float to signed 16 | float | int16 |
- | signed 8 to float | int8 | float |
- | signed 16 to float | int16 | float |
- | float 32 to float 64 | float32 | float64 |
- | float 64 to float 32 | float64 | float32 |
- }];
-
- let arguments = (ins
- Tosa_Tensor:$input
- );
-
- let results = (outs
- Tosa_Tensor:$output
+ | bool to int 8 | Boolean | int8 |
+ | bool to int 16 | Boolean | int16 |
+ | bool to int 32 | Boolean | int32 |
+ | int 8 to bool | int8 | Boolean |
+ | int 8 to int 16 | int8 | int16 |
+ | int 8 to int 32 | int8 | int32 |
+ | int 8 to fp16 | int8 | float16 |
+ | int 8 to bf16 | int8 | bf16 |
+ | int 8 to fp32 | int8 | float32 |
+ | int 16 to bool | int16 | Boolean |
+ | int 16 to int 8 | int16 | int8 |
+ | int 16 to int 32 | int16 | int32 |
+ | int 16 to fp16 | int16 | float16 |
+ | int 16 to bf16 | int16 | bf16 |
+ | int 16 to fp32 | int16 | float32 |
+ | int 32 to bool | int32 | Boolean |
+ | int 32 to int 8 | int32 | int8 |
+ | int 32 to int 16 | int32 | int16 |
+ | int 32 to fp16 | int32 | float16 |
+ | int 32 to bf16 | int32 | bf16 |
+ | int 32 to fp32 | int32 | float32 |
+ | bf16 to int 8 | bf16 | int8 |
+ | bf16 to int 16 | bf16 | int16 |
+ | bf16 to int 32 | bf16 | int32 |
+ | bf16 to fp8e4m3 | bf16 | fp8e4m3 |
+ | bf16 to fp8e5m2 | bf16 | fp8e5m2 |
+ | bf16 to fp32 | bf16 | float32 |
+ | fp8e4m3 to fp16 | fp8e4m3 | float16 |
+ | fp8e4m3 to bf16 | fp8e4m3 | bf16 |
+ | fp8e4m3 to fp32 | fp8e4m3 | float32 |
+ | fp8e5m2 to fp16 | fp8e5m2 | float16 |
+ | fp8e5m2 to bf16 | fp8e5m2 | bf16 |
+ | fp8e5m2 to fp32 | fp8e5m2 | float32 |
+ | fp16 to int 8 | float16 | int8 |
+ | fp16 to int 16 | float16 | int16 |
+ | fp16 to int 32 | float16 | int32 |
+ | fp16 to fp8e4m3 | float16 | fp8e4m3 |
+ | fp16 to fp8e5m2 | float16 | fp8e5m2 |
+ | fp16 to fp32 | float16 | float32 |
+ | fp32 to int 8 | float32 | int8 |
+ | fp32 to int 16 | float32 | int16 |
+ | fp32 to int 32 | float32 | int32 |
+ | fp32 to fp8e4m3 | float32 | fp8e4m3 |
+ | fp32 to fp8e5m2 | float32 | fp8e5m2 |
+ | fp32 to bf16 | float32 | bf16 |
+ | fp32 to fp16 | float32 | float16 |
+ }];
+
+ let arguments = (ins
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$input
+ );
+
+ let results = (outs
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64]>]>:$output
);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1940,7 +1979,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);
let results = (outs
- TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Extended, F64, Tosa_Int4]>]>:$output
);
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cf6ddc66f4ada..2c6e647ae32fd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -74,16 +74,25 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
Tosa_QuantizedType<"int16", [16, 0], 1>,
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_F8 : AnyTypeOf<[
+ F8E4M3FN,
+ F8E5M2]>;
+
//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;
+// Add F8 type support to Tosa_AnyNumber
+def Tosa_AnyNumber_Extended : AnyTypeOf<[Tosa_AnyNumber, Tosa_F8],
+ "number_extended">;
+
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
- Tosa_QuantizedInt, AnyFloat]>;
+ Tosa_QuantizedInt, AnyFloat, Tosa_F8]>;
+
//===----------------------------------------------------------------------===//
// TOSA Tensor Conformance
@@ -130,9 +139,11 @@ def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor_Extended : TosaTensorOf<[Tosa_AnyNumber_Extended]>;
// Must be ranked but no further constraints
-def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor_Extended : RankedTensorOf<[Tosa_AnyNumber_Extended]>;
// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
@@ -145,9 +156,9 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//
-def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
-
+// Scalar tensors: Rank-1 (with only one element)
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
+def Tosa_ScalarTensor_Extended : TosaScalarTensorOf<[Tosa_AnyNumber_Extended], [1]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
// We include unranked tensors as a supported type for all possible tosa
@@ -155,6 +166,7 @@ def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
// they should be shape propagate used Tosa's shape inference pass and verified
// to not include any remaining unranked tensors.
def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensorExtended : TosaUnrankedTensorOf<[Tosa_AnyNumber_Extended]>;
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
@@ -162,6 +174,17 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor1D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [1]>],
+ "1-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor2D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [2]>],
+ "2-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor3D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [3]>],
+ "3-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor4D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [4]>],
+ "4-d tosa-conformant tensor extended", "::mlir::TensorType">;
+def Tosa_Tensor5D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [5]>],
+ "5-d tosa-conformant tensor extended", "::mlir::TensorType">;
+
// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 69b3f6d674167..704f8a82d11fa 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
- auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
+ auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
return failure();
rewriter.modifyOpInPlace(op, [&]() {
op.getOperation()->setOperands(
- {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
+ {notOp.getInput1(), op.getInput3(), op.getInput2()});
});
return success();
}
@@ -1118,18 +1118,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
- if (getOnTrue() == getOnFalse())
- return getOnTrue();
+ if (getInput2() == getInput3())
+ return getInput2();
auto predicate =
- llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
+ llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
if (!predicate)
return {};
if (!predicate.isSplat())
return {};
- return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
- : getOnFalse();
+ return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
+ : getInput3();
}
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..411f06f4a0b7c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -217,15 +217,17 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
template <typename T>
static LogicalResult verifyConvOp(T op) {
- // All TOSA conv ops have an input and weight arguments which must be ranked
- // tensors.
+ // All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+
+ Ra...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Add FP8 support to following TOSA operators: ARGMAX AVGPOOL CONV2D CONV3D DEPTHWISE_CONV2D MATMUL MAX_POOL2D TRANSPOSE_CONV2D CONST CAST CONCAT PAD DIM RESHAPE REVERSE SLICE TILE TRANSPOSE GATHER SCATTER Also added verifiers as needed to check input/output element types and renamed inputs of transpose_conv2d and select to match spec. Signed-off-by: Tai Ly <[email protected]> Signed-off-by: Jerry Ge <[email protected]> Change-Id: I56adfabb2396e38b7ed3479e4fd680b740bdb4e4
//===----------------------------------------------------------------------===// | ||
// Multi-category types. | ||
//===----------------------------------------------------------------------===// | ||
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat], | ||
"number">; | ||
|
||
// Add F8 type support to Tosa_AnyNumber | ||
def Tosa_AnyNumber_Extended : AnyTypeOf<[Tosa_AnyNumber, Tosa_F8], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need the extended type here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The goal here is to differentiate between operators that take FP8 and those that do not take FP8.
Another way is to do the following:
- define
Tosa_AnyNumber
to include every dtypes. - define something like
Tosa_AnyNumber_Exclude_FP8
to remove FP8 from that set.
But this is no better than the existing solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tosa_AnyNumber
already uses AnyFloat
so I believe it could already support FP8?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, right. We didn't have this a year ago.
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>; | ||
def Tosa_ScalarTensor_Extended : TosaScalarTensorOf<[Tosa_AnyNumber_Extended], [1]>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here. We need to replicate extended semantics everywhere. Wouldn't it be cleaner to allow all types and restrict on validation passes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same reason above.
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType()); | ||
|
||
RankedTensorType weightType; | ||
weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why move it up and not declare it before use?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there's any particular reason. We can move it down if you prefer that way.
llvm::isa<Float8E4M3FNType>(inputETy)) && | ||
!accType.isF16()) | ||
return emitOpError("accumulator type for f8 tensor is not f16"); | ||
|
||
if ((inputETy.isF32() && resultETy.isF32()) || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we plan to keep growing this? We could probably simplify it a bit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see us growing this in the short term. It reads clearly to me. If we refactor this into another function, it takes additional time for me to find that function and check it.
} | ||
|
||
// input element type: bool | ||
if (inputETy.isInteger(1)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we do this with a lambda and reuse?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can. But how much code do we really save here? This check is simple enough and easy to read. I think adding a lambda is extra-overhead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change refactored to this patch: #127923
@@ -849,6 +969,18 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents( | |||
return success(); | |||
} | |||
|
|||
LogicalResult tosa::ConcatOp::verify() { | |||
// check that each input has same element type as output | |||
auto outType = getOutput().getType(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use llvm::all_of?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we can do that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change refactored to this patch: #127923
@@ -1238,6 +1487,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents( | |||
} | |||
|
|||
LogicalResult tosa::TileOp::verify() { | |||
if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are all these verifiers part of the fp8 support? It doesn't feel like it. Can we pull into a separate patch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it doesn't look like it. I will put them into other patches.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change refactored to this patch: #127923
@@ -117,7 +117,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { | |||
|
|||
Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad); | |||
|
|||
auto padTy = RankedTensorType::get({}, inputETy); | |||
auto padTy = RankedTensorType::get({1}, inputETy); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not an fp8 related change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack.
if ((llvm::isa<Float8E5M2Type>(inputETy) || | ||
llvm::isa<Float8E4M3FNType>(inputETy)) && | ||
!accType.isF16()) | ||
return emitOpError("accumulator type for f8 tensor is not f16"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These seem too restrictive. The max value for the Float8E5M2 type is 57344, while the max value of an fp16 accumulator is only around ~65k. FP8 requires an fp32 accumulator, not fp16.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(added a related comment above before I saw this one: https://github.com/llvm/llvm-project/pull/127730/files#r1962023091)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My concern is about the validity of having an fp16 accumulator, not about where we do that validation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @umangyadav. This is a reasonable request. We'll add another FP32 accumulator type likely in 1.1.
if (inputEType != weightEType) { | ||
op.emitOpError( | ||
"expect both input and weight to have same element type, got ") | ||
<< inputEType << " and " << weightEType; | ||
return failure(); | ||
} | ||
|
||
if (!resultEType.isF16()) { | ||
op.emitOpError("expect bias and result element type to be f16, got ") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should move some of these checks to the validation pass in order not to make the dialect too restrictive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea. Since we currently don't have anything for dtype checking here: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp, it feels like it should be another separate patch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These checks in verify functions prevent construction of ops with invalid data types.
Moving the checking to the validation pass would mean we allow construction of ops with invalid data types, and then only disallow those ops when the validation pass is run. This may free up the dialect too much?
} | ||
} | ||
// input element type: int8 | ||
if (inputETy.isInteger(8)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above (referring to moving to the validation pass in order to not make the dialect too restrictive)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
Add FP8 support to following TOSA operators:
TOSA FP8 Extensions: https://mltech.ml.arm.com/spec/tosa/nightly/tosa_spec.html#_ext_fp8e4m3_extension
ARGMAX
AVGPOOL
CONV2D
CONV3D
DEPTHWISE_CONV2D
MATMUL
MAX_POOL2D
TRANSPOSE_CONV2D
CONST
CAST
CONCAT
PAD
DIM
RESHAPE
REVERSE
SLICE
TILE
TRANSPOSE
GATHER
SCATTER
Also added verifiers as needed to check input/output element types and renamed inputs of transpose_conv2d and select to match spec.