Skip to content

Commit

Permalink
[mlir][tosa] Remove FullyConnectedOp from TOSA Dialect (llvm#126152)
Browse files Browse the repository at this point in the history
This patch removes the FullyConnected operator from the TOSA Dialect and
all associated tests and transforms.

This is part of the TOSA v1.0 alignment effort:
https://discourse.llvm.org/t/rfc-tosa-dialect-increment-to-v1-0/83708

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Tai Ly <[email protected]>
  • Loading branch information
2 people authored and joaosaffran committed Feb 14, 2025
1 parent 4a94aa2 commit 79fdb8c
Show file tree
Hide file tree
Showing 17 changed files with 4 additions and 651 deletions.
9 changes: 0 additions & 9 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,6 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
outputShape, acc_type);
}]>;

// The tosa.fully_connected op has its own builder as it does not have
// strides/dilation/padding.
//
// Builder arguments: the result type followed by the input, weight and bias
// values. The body forwards them, together with the implicit builder and
// operation state, to the C++ helper buildFCOpWithQuantInfo.
def Tosa_FCOpQuantInfoBuilder : OpBuilder<
(ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias),
[{
buildFCOpWithQuantInfo($_builder, $_state, outputType,
input, weight, bias);
}]>;

// The tosa.matmul op is also intended to be generated where a fully_connected
// op must be constructed where the weight is not a constant. In this case,
// the fully_connected op must be expressed using matmul.
Expand Down
26 changes: 0 additions & 26 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -224,32 +224,6 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
}];
}

//===----------------------------------------------------------------------===//
// Operator: fully_connected
//===----------------------------------------------------------------------===//
// Definition of the tosa.fully_connected operation ("fully_connected"
// mnemonic), derived from Tosa_InferShapedTypeOp.
def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
let summary = "Fully Connected operator";

let description = [{
Performs a fully connected network.
}];

// Operands: a 2-D input tensor, a rank-2 weight tensor whose element type
// satisfies Tosa_Weight, and a 1-D bias tensor. The two zero-point
// attributes are optional 32-bit integers.
let arguments = (ins
Tosa_Tensor2D:$input,
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
Tosa_Tensor1D:$bias,
OptionalAttr<I32Attr>:$input_zp,
OptionalAttr<I32Attr>:$weight_zp
);

// Single result: a 2-D output tensor.
let results = (outs
Tosa_Tensor2D:$output
);

// Custom builder taking (outputType, input, weight, bias); a C++ verify()
// implementation is required because hasVerifier is set.
let builders = [Tosa_FCOpQuantInfoBuilder];
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Operator: matmul
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;

// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, AnyFloat]>;

Expand Down
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ namespace tosa {

// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
// The rewrites can be selectively added to a conversion pass.
void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
Expand Down
79 changes: 0 additions & 79 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,84 +607,6 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
}
};

/// Conversion pattern that rewrites tosa::FullyConnectedOp in terms of
/// linalg matmul ops.
///
/// The weight operand is transposed with the permutation {1, 0} through a
/// tosa::TransposeOp, the bias is broadcast into an empty tensor of the
/// output shape and passed as the matmul's output operand, and the op is then
/// replaced by linalg::MatmulOp (no zero points) or
/// linalg::QuantizedMatmulOp (both zero points present).
class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;

  /// Rewrites `op`; returns failure without modifying the IR when only one
  /// of the two zero-point attributes is present.
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // The quantized lowering materializes both zero points as constants.
    // Bail out before creating any IR if exactly one of them is set;
    // otherwise a null attribute would be handed to arith::ConstantOp.
    const bool hasInputZp = static_cast<bool>(op.getInputZp());
    const bool hasWeightZp = static_cast<bool>(op.getWeightZp());
    if (hasInputZp != hasWeightZp)
      return rewriter.notifyMatchFailure(
          op, "input_zp and weight_zp must be either both set or both unset");

    Location loc = op.getLoc();
    auto outputTy = cast<ShapedType>(op.getType());
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());

    auto bias = op.getBias();

    auto weight = op.getWeight();
    auto weightTy = cast<ShapedType>(weight.getType());
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    // One (possibly null) size value per result dimension; only dynamic
    // dimensions get a tensor.dim value.
    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    // Result dim 0 is taken from input dim 0.
    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    // Result dim 1 is taken from weight dim 0.
    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    // Swap the two weight dimensions so the weight can be used as the
    // right-hand matmul operand.
    SmallVector<int64_t> permutation = {1, 0};
    auto permutationAttr = rewriter.getI64TensorAttr(permutation);
    Value permutationValue =
        rewriter.create<arith::ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape = {weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    // The broadcast bias is the initial value of the matmul's output operand.
    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

    // After the guard above the two flags are equal, so testing one suffices.
    if (!hasInputZp) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, broadcastBias)
                         ->getResult(0);

      rewriter.replaceOp(op, matmul);
      return success();
    }

    auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
    auto weightZp =
        rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, weightZp},
                broadcastBias)
            ->getResult(0);

    rewriter.replaceOp(op, matmul);
    return success();
  }
};

class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -1090,7 +1012,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
DepthwiseConvConverter,
MatMulConverter,
AvgPool2dConverter,
FullyConnectedConverter,
TransposeConverter
>(patterns->getContext());

Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ struct TosaToLinalgNamed
target.addIllegalOp<tosa::MaxPool2dOp>();
target.addIllegalOp<tosa::AvgPool2dOp>();
target.addIllegalOp<tosa::MatMulOp>();
target.addIllegalOp<tosa::FullyConnectedOp>();
target.addIllegalOp<tosa::TransposeOp>();

target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
Expand Down
93 changes: 3 additions & 90 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,26 +566,9 @@ static void buildTransConvOpWithQuantInfo(
result.addTypes(finalOutputType);
}

/// Builder for tosa.fully_connected, which has no strides/dilation/padding
/// and therefore does not share the convolution builders.
///
/// Adds `input`, `weight` and `bias` as operands. When
/// buildConvOpQuantizationAttr yields an attribute for (`input`, `weight`),
/// it is attached as "quantization_info" and the result type comes from
/// buildConvOpResultTypeInfo; otherwise `outputType` is used unchanged.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {
  result.addOperands({input, weight, bias});

  auto quantInfo = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (!quantInfo) {
    // No quantization attribute: keep the requested result type as-is.
    result.addTypes(outputType);
    return;
  }

  result.addAttribute("quantization_info", quantInfo);
  result.addTypes(
      buildConvOpResultTypeInfo(builder, outputType, input, weight));
}

/// The tosa.matmul op is also intended to be generated where a
/// fully_connected op must be constructed where the weight is not a constant.
/// In this case, the fully_connected op must be expressed using matmul.
/// The tosa.matmul op is also intended to be generated where a fully_connected
/// op must be constructed where the weight is not a constant. In this case,
/// the fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
OperationState &result, Type outputType,
Expand Down Expand Up @@ -889,76 +872,6 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return succeeded(verifyCompatibleShape(l[0], r[0]));
}

/// Infers the 2-D result shape of tosa.fully_connected.
///
/// Dimension 0 is copied from dimension 0 of the input when the input is
/// ranked. Dimension 1 is copied from dimension 0 of the weight when the
/// weight is ranked; if it is still dynamic afterwards and the bias is
/// ranked, dimension 0 of the bias is used instead. Anything that cannot be
/// determined stays dynamic. Always succeeds.
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FullyConnectedOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inShape(adaptor.getInput().getType());
  ShapeAdaptor wtShape(adaptor.getWeight().getType());
  ShapeAdaptor bShape(adaptor.getBias().getType());

  // Start from a fully dynamic rank-2 shape and refine it.
  SmallVector<int64_t> resultDims(2, ShapedType::kDynamic);

  if (inShape.hasRank())
    resultDims[0] = inShape.getDimSize(0);

  if (wtShape.hasRank())
    resultDims[1] = wtShape.getDimSize(0);

  // The bias only contributes when the weight did not pin dimension 1.
  if (bShape.hasRank() && resultDims[1] == ShapedType::kDynamic)
    resultDims[1] = bShape.getDimSize(0);

  inferredReturnShapes.push_back(ShapedTypeComponents(resultDims));
  return success();
}

/// Verifies tosa.fully_connected.
///
/// Checks, in order, that:
///   - the input and the weight are ranked tensors;
///   - the input and weight element types are either both floating point or
///     both non floating point;
///   - the input zero point attribute is present exactly when the input
///     element type is not floating point.
/// Only the input zero point is inspected here; the weight zero point is not.
LogicalResult FullyConnectedOp::verify() {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto weightTy = llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  // Both operands must be ranked tensors.
  if (!inputTy)
    return emitOpError("expect a ranked tensor for input, got ") << getInput();
  if (!weightTy)
    return emitOpError("expect a ranked tensor for weight, got ")
           << getWeight();

  auto inputElemTy = inputTy.getElementType();
  auto weightElemTy = weightTy.getElementType();

  // Any non-float element type is treated as quantized.
  const bool inputQuantized = !llvm::isa<FloatType>(inputElemTy);
  const bool weightQuantized = !llvm::isa<FloatType>(weightElemTy);

  // Mixing a float operand with a non-float operand is rejected.
  if (inputQuantized != weightQuantized)
    return emitOpError(
               "expect both input and weight to be float or not together, got ")
           << inputElemTy << " and " << weightElemTy;

  // The input zero point must be present for quantized inputs and absent for
  // float inputs.
  const bool hasInputZp = static_cast<bool>(getInputZp());
  if (inputQuantized != hasInputZp)
    return emitOpError("input zero point is required for quantized type, and not "
                       "allowed for float type");

  return success();
}

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
MatMulOp::Adaptor adaptor,
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeTransposeConv.cpp
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaFolders.cpp
TosaInferShapes.cpp
Expand Down
Loading

0 comments on commit 79fdb8c

Please sign in to comment.