
[mlir][tosa] Remove FullyConnectedOp from TOSA Dialect #126152

Merged — 1 commit into llvm:main on Feb 13, 2025

Conversation

@Jerry-Ge (Member) commented on Feb 6, 2025

This patch removes the FullyConnected operator from the TOSA Dialect and all associated tests and transforms.

This is part of the TOSA v1.0 alignment effort: https://discourse.llvm.org/t/rfc-tosa-dialect-increment-to-v1-0/83708
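
For downstream users, here is a minimal sketch of how a former tosa.fully_connected can be expressed after this removal, mirroring the transpose-plus-matmul lowering that the deleted FullyConnectedConverter performed. The value names and shapes are illustrative, zero points are ignored, and the op syntax follows the dialect at roughly the time of this patch, so it may differ in newer versions:

// fully_connected: %input: tensor<4x8xf32>, %weight: tensor<16x8xf32> (OC x IC),
// %bias: tensor<16xf32> -> tensor<4x16xf32>
%perm = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%w_t  = tosa.transpose %weight, %perm : (tensor<16x8xf32>, tensor<2xi32>) -> tensor<8x16xf32>
// tosa.matmul is rank-3 (batched), so lift both operands to a unit batch.
%in3  = tosa.reshape %input {new_shape = array<i64: 1, 4, 8>} : (tensor<4x8xf32>) -> tensor<1x4x8xf32>
%w3   = tosa.reshape %w_t {new_shape = array<i64: 1, 8, 16>} : (tensor<8x16xf32>) -> tensor<1x8x16xf32>
%mm   = tosa.matmul %in3, %w3 : (tensor<1x4x8xf32>, tensor<1x8x16xf32>) -> tensor<1x4x16xf32>
%mm2  = tosa.reshape %mm {new_shape = array<i64: 4, 16>} : (tensor<1x4x16xf32>) -> tensor<4x16xf32>
// Broadcast the bias add (TOSA broadcasting requires equal ranks).
%b2   = tosa.reshape %bias {new_shape = array<i64: 1, 16>} : (tensor<16xf32>) -> tensor<1x16xf32>
%out  = tosa.add %mm2, %b2 : (tensor<4x16xf32>, tensor<1x16xf32>) -> tensor<4x16xf32>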

@llvmbot (Member) commented on Feb 6, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: Jerry-Ge (Jerry-Ge)

Changes

This patch removes the FullyConnected operator from the TOSA Dialect and all associated tests and transforms.

This is part of the TOSA v1.0 alignment effort: https://discourse.llvm.org/t/rfc-tosa-dialect-increment-to-v1-0/83708


Patch is 38.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126152.diff

17 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (-9)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (-26)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h (-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (-79)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+3-90)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (-1)
  • (removed) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (-164)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp (-1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (-14)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (-71)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (-20)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (-7)
  • (removed) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (-76)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (-45)
  • (removed) mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir (-36)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f492bad78e775ca..862d98ad436a659 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -150,15 +150,6 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
                                   outputShape, acc_type);
   }]>;
 
-// The tosa.fully_connected op has its own builder as it does not have
-// strides/dilation/padding.
-def Tosa_FCOpQuantInfoBuilder : OpBuilder<
-  (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias),
-  [{
-    buildFCOpWithQuantInfo($_builder, $_state, outputType,
-                           input, weight, bias);
-  }]>;
-
 // The tosa.matmul op is also intended to be generated where a fully_connected
 // op must be constructed where the weight is not a constant. In this case,
 // the fully_connected op must be expressed using matmul.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 98bcbca3b02fa12..74b93bcb371fa47 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -224,32 +224,6 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// Operator: fully_connected
-//===----------------------------------------------------------------------===//
-def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
-  let summary = "Fully Connected operator";
-
-  let description = [{
-    Performs a fully connected network.
-  }];
-
-  let arguments = (ins
-    Tosa_Tensor2D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
-    Tosa_Tensor1D:$bias,
-    OptionalAttr<I32Attr>:$input_zp,
-    OptionalAttr<I32Attr>:$weight_zp
-  );
-
-  let results = (outs
-    Tosa_Tensor2D:$output
-  );
-
-  let builders = [Tosa_FCOpQuantInfoBuilder];
-  let hasVerifier = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // Operator: matmul
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 7aa1f72ec6e1792..6d5cf9182736f0a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -81,7 +81,7 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                 "number">;
 
 // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
-// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
+// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
                              Tosa_QuantizedInt, AnyFloat]>;
 
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 1f9522b51a4cf5c..565970367e5dc56 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -26,7 +26,6 @@ namespace tosa {
 
 // Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
 // The rewrites can be selectively added to a conversion pass.
-void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
 void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
                                         RewritePatternSet &patterns);
 void populateTosaDecomposeDepthwise(MLIRContext *ctx,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 6321cb6087394a3..a8fd536dd254842 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -607,84 +607,6 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
   }
 };
 
-class FullyConnectedConverter
-    : public OpConversionPattern<tosa::FullyConnectedOp> {
-public:
-  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    Location loc = op.getLoc();
-    auto outputTy = cast<ShapedType>(op.getType());
-    auto input = op.getInput();
-    auto inputTy = cast<ShapedType>(input.getType());
-
-    auto bias = op.getBias();
-
-    auto weight = op.getWeight();
-    auto weightTy = cast<ShapedType>(weight.getType());
-    auto weightShape = weightTy.getShape();
-
-    auto outputETy = outputTy.getElementType();
-
-    SmallVector<Value> dynDims;
-    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
-
-    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
-      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
-    }
-
-    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
-      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
-    }
-
-    SmallVector<Value> filteredDims = condenseValues(dynDims);
-
-    SmallVector<int64_t> permutation = {1, 0};
-    auto permutationAttr = rewriter.getI64TensorAttr(permutation);
-    Value permutationValue =
-        rewriter.create<arith::ConstantOp>(loc, permutationAttr);
-
-    SmallVector<int64_t> newWeightShape = {weightShape[1], weightShape[0]};
-    Type newWeightTy =
-        RankedTensorType::get(newWeightShape, weightTy.getElementType());
-
-    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
-        loc, newWeightTy, weight, permutationValue);
-
-    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, outputTy.getShape(), outputETy, filteredDims);
-
-    Value broadcastBias =
-        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
-
-    if (!op.getInputZp() && !op.getWeightZp()) {
-      Value matmul = rewriter
-                         .create<linalg::MatmulOp>(
-                             loc, TypeRange{op.getType()},
-                             ValueRange{input, transposedWeight}, broadcastBias)
-                         ->getResult(0);
-
-      rewriter.replaceOp(op, matmul);
-      return success();
-    }
-
-    auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
-    auto outputZp =
-        rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
-    Value matmul =
-        rewriter
-            .create<linalg::QuantizedMatmulOp>(
-                loc, TypeRange{op.getType()},
-                ValueRange{input, transposedWeight, inputZp, outputZp},
-                broadcastBias)
-            ->getResult(0);
-
-    rewriter.replaceOp(op, matmul);
-    return success();
-  }
-};
-
 class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -1090,7 +1012,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       DepthwiseConvConverter,
       MatMulConverter,
       AvgPool2dConverter,
-      FullyConnectedConverter,
       TransposeConverter
   >(patterns->getContext());
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 7d943b3779fb02e..80df3908991dde2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -62,7 +62,6 @@ struct TosaToLinalgNamed
     target.addIllegalOp<tosa::MaxPool2dOp>();
     target.addIllegalOp<tosa::AvgPool2dOp>();
     target.addIllegalOp<tosa::MatMulOp>();
-    target.addIllegalOp<tosa::FullyConnectedOp>();
     target.addIllegalOp<tosa::TransposeOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 031c279ff09e275..be54319e64fde68 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -540,26 +540,9 @@ static void buildTransConvOpWithQuantInfo(
   }
 }
 
-/// The tosa.fully_connected op has its own builder as it does not have
-/// strides/dilation/padding.
-static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
-                                   Type outputType, Value input, Value weight,
-                                   Value bias) {
-
-  result.addOperands({input, weight, bias});
-  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
-  if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
-    result.addTypes(
-        buildConvOpResultTypeInfo(builder, outputType, input, weight));
-  } else {
-    result.addTypes(outputType);
-  }
-}
-
-/// The tosa.matmul op is also intended to be generated where a
-/// fully_connected op must be constructed where the weight is not a constant.
-/// In this case, the fully_connected op must be expressed using matmul.
+/// The tosa.matmul op is also intended to be generated where a fully_connected
+/// op must be constructed where the weight is not a constant. In this case,
+/// the fully_connected op must be expressed using matmul.
 /// TODO: Add link to the leglization document explaining this.
 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                        OperationState &result, Type outputType,
@@ -863,76 +846,6 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   return succeeded(verifyCompatibleShape(l[0], r[0]));
 }
 
-LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
-    MLIRContext *context, ::std::optional<Location> location,
-    FullyConnectedOp::Adaptor adaptor,
-    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape(adaptor.getInput().getType());
-  ShapeAdaptor weightShape(adaptor.getWeight().getType());
-  ShapeAdaptor biasShape(adaptor.getBias().getType());
-
-  // All shapes are dynamic.
-  SmallVector<int64_t> outShape;
-  outShape.resize(2, ShapedType::kDynamic);
-
-  if (inputShape.hasRank()) {
-    outShape[0] = inputShape.getDimSize(0);
-  }
-
-  if (weightShape.hasRank()) {
-    outShape[1] = weightShape.getDimSize(0);
-  }
-
-  if (biasShape.hasRank()) {
-    outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
-                                                      : outShape[1];
-  }
-
-  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
-  return success();
-}
-
-LogicalResult FullyConnectedOp::verify() {
-  // All TOSA conv ops have an input() and weight().
-  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
-
-  RankedTensorType weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
-  // Must be ranked tensor types
-  if (!inputType) {
-    emitOpError("expect a ranked tensor for input, got ") << getInput();
-    return failure();
-  }
-  if (!weightType) {
-    emitOpError("expect a ranked tensor for weight, got ") << getWeight();
-    return failure();
-  }
-
-  auto inputEType = inputType.getElementType();
-  auto weightEType = weightType.getElementType();
-
-  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
-  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
-
-  // Either both must be quantized or both unquantized.
-  if (inputIsQuant != weightIsQuant) {
-    emitOpError(
-        "expect both input and weight to be float or not together, got ")
-        << inputEType << " and " << weightEType;
-    return failure();
-  }
-
-  // Quantized type must have constructed the quantizationattr, and unquantized
-  // types should not have a quantizationattr.
-  if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
-    emitOpError("input zero point is required for quantized type, and not "
-                "allowed for float type");
-    return failure();
-  }
-  return success();
-}
-
 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     MatMulOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 5b0f5ec4cd5687b..9c3345b617cc5f3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
-  TosaDecomposeConv2D.cpp
   TosaDecomposeDepthwise.cpp
   TosaFolders.cpp
   TosaInferShapes.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
deleted file mode 100644
index 4eba89b59bbd79f..000000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ /dev/null
@@ -1,164 +0,0 @@
-//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
-// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
-  return to_vector(llvm::map_range(shape, [](int64_t dim) {
-    return ShapedType::isDynamic(dim) ? -1 : dim;
-  }));
-}
-
-struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
-  explicit Conv2DIsFullyConnected(MLIRContext *context)
-      : OpRewritePattern(context) {}
-
-  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    Value weight = op.getWeight();
-    ShapedType inputType = cast<ShapedType>(input.getType());
-    ShapedType weightType = cast<ShapedType>(weight.getType());
-    ShapedType resultType = cast<ShapedType>(op.getType());
-
-    auto numDynamic =
-        llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
-    if (numDynamic > 1)
-      return rewriter.notifyMatchFailure(
-          op, "at most one dim in input may be dynamic");
-    if (!weightType.hasRank())
-      return rewriter.notifyMatchFailure(op, "unranked weight input");
-
-    if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
-      return failure();
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1)
-      return failure();
-
-    llvm::ArrayRef<int64_t> padAttr = op.getPad();
-    llvm::SmallVector<int64_t> pad(8, 0);
-    for (const auto &it : llvm::enumerate(padAttr))
-      pad[it.index() + 2] = it.value();
-
-    Type inputETy = inputType.getElementType();
-    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
-      auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-      if (failed(failureOrMaybeZps))
-        return failure();
-
-      auto maybeZps = failureOrMaybeZps.value();
-
-      Attribute zeroAttr =
-          maybeZps ? rewriter.getIntegerAttr(inputETy, maybeZps->inputZp)
-                   : rewriter.getZeroAttr(inputETy);
-
-      llvm::SmallVector<int64_t> newShape(inputType.getShape());
-
-      for (int i = 0, s = newShape.size(); i < s; ++i) {
-        if (newShape[i] != ShapedType::kDynamic) {
-          newShape[i] += pad[i * 2] + pad[i * 2 + 1];
-        }
-      }
-
-      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
-
-      auto padTy = RankedTensorType::get({}, inputETy);
-      auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
-      Value padVal =
-          rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
-      inputType = RankedTensorType::get(newShape, inputETy);
-      input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
-                                           padSizeVal, padVal);
-    }
-
-    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    int64_t combined = ShapedType::kDynamic;
-    if (numDynamic == 0)
-      combined = inputShape[0] * inputShape[1] * inputShape[2];
-    llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
-    auto revisedInputShapeType =
-        RankedTensorType::get(revisedInputShape, inputType.getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getDenseI64ArrayAttr(
-                                     convertFromMlirShape(revisedInputShape)))
-                             .getResult();
-
-    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        dyn_cast<RankedTensorType>(weight.getType()).getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getDenseI64ArrayAttr(
-                                      convertFromMlirShape(revisedWeightShape)))
-                              .getResult();
-
-    // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
-    auto fullyConnectedShapeType =
-        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
-
-    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-    if (failed(failureOrMaybeZps))
-      return failure();
-
-    auto maybeZps = failureOrMaybeZps.value();
-    Value fullyConnectedValue;
-    if (maybeZps) {
-      fullyConnectedValue =
-          rewriter
-              .create<tosa::FullyConnectedOp>(
-                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.getBias(),
-                  rewriter.getI32IntegerAttr(maybeZps->inputZp),
-                  rewriter.getI32IntegerAttr(maybeZps->weightZp))
-              .getResult();
-    } else {
-      fullyConnectedValue = rewriter
-                                .create<tosa::FullyConnectedOp>(
-                                    op.getLoc(), fullyConnectedShapeType,
-                                    reshapedInput, reshapedWeight, op.getBias())
-                                .getResult();
-    }
-
-    // Reshape output to [N, IH, IW, OC].
-    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
-                                              inputShape[2], weightShape[0]};
-    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultType, fullyConnectedValue,
-        rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
-                                             RewritePatternSet &patterns) {
-  patterns.add<Conv2DIsFullyConnected>(ctx);
-}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
index 603185e48aa94e1..ffa2ea3d0629f37 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
@@ -38,7 +38,6 @@ struct TosaOptionalDecompositions
     RewritePatternSet patterns(ctx);
     auto func = getOperation();
 
-    mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
   ...
[truncated]

@lhutton1 changed the title from "Remove FullyConnectedOp from TOSA Dialect" to "[mlir][tosa] Remove FullyConnectedOp from TOSA Dialect" on Feb 12, 2025
* This patch removes the FullyConnected operator from the TOSA Dialect
  and all associated tests and transforms.

* The TosaDecomposeConv2D pass is also removed; expect this rewrite to
  be done as a backend optimization instead (see the sketch below).

Signed-off-by: Tai Ly <[email protected]>
Change-Id: Ib8c928cb21daf325f00cdad302680af2d7c13da5
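
For reference, a hedged sketch of what the removed TosaDecomposeConv2D pass did for a stride-1, unpadded 1x1 tosa.conv2d, which backends are now expected to handle as their own optimization. Shapes are illustrative, and the final matrix-multiply step is the same pattern as the fully_connected sketch above:

// conv2d: %input: tensor<1x4x4x8xf32> (NHWC), %weight: tensor<16x1x1x8xf32> (OHWI)
// Collapse spatial dims: [N,H,W,IC] -> [N*H*W, IC], and [OC,1,1,IC] -> [OC, IC].
%in2 = tosa.reshape %input {new_shape = array<i64: 16, 8>} : (tensor<1x4x4x8xf32>) -> tensor<16x8xf32>
%w2  = tosa.reshape %weight {new_shape = array<i64: 16, 8>} : (tensor<16x1x1x8xf32>) -> tensor<16x8xf32>
// ...then transpose %w2, matmul, and bias-add as in the sketch above, and
// reshape the [N*H*W, OC] result back to tensor<1x4x4x16xf32>.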
@Jerry-Ge merged commit 4ec1990 into llvm:main on Feb 13, 2025
8 checks passed
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025