Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][OpenMP]Add prescriptiveness-modifier support to grainsize and … #128477

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

kaviya2510
Copy link
Contributor

…num_tasks clause.

@llvmbot
Copy link
Member

llvmbot commented Feb 24, 2025

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-mlir

Author: Kaviya Rajendiran (kaviya2510)

Changes

…num_tasks clause.


Full diff: https://github.com/llvm/llvm-project/pull/128477.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+6-8)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+113-9)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+24)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index a8d97a36df79e..32c28f72ec8e5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -436,12 +436,11 @@ class OpenMP_GrainsizeClauseSkip<
     bit description = false, bit extraClassDeclaration = false
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
-  let arguments = (ins
-    Optional<IntLikeType>:$grainsize
-  );
+  let arguments = (ins OptionalAttr<GrainsizeTypeAttr>:$grainsize_mod,
+      Optional<IntLikeType>:$grainsize);
 
   let optAssemblyFormat = [{
-    `grainsize` `(` $grainsize `:` type($grainsize) `)`
+    `grainsize` `(` custom<GrainsizeClause>($grainsize_mod , $grainsize, type($grainsize)) `)`
   }];
 
   let description = [{
@@ -895,12 +894,11 @@ class OpenMP_NumTasksClauseSkip<
     bit description = false, bit extraClassDeclaration = false
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
-  let arguments = (ins
-    Optional<IntLikeType>:$num_tasks
-  );
+  let arguments = (ins OptionalAttr<NumTasksTypeAttr>:$num_tasks_mod,
+      Optional<IntLikeType>:$num_tasks);
 
   let optAssemblyFormat = [{
-    `num_tasks` `(` $num_tasks `:` type($num_tasks) `)`
+    `num_tasks` `(` custom<NumTasksClause>($num_tasks_mod , $num_tasks, type($num_tasks)) `)`
   }];
 
   let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d725a457aeff6..f8b948ff98864 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -472,6 +472,108 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
     p << stringifyClauseOrderKind(order.getValue());
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for grainsize Clause
+//===----------------------------------------------------------------------===//
+
+// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
+static ParseResult
+parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
+                     std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
+                     Type &grainsizeType) {
+  SMLoc loc = parser.getCurrentLocation();
+  StringRef enumStr;
+
+  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
+    if (std::optional<ClauseGrainsizeType> enumValue =
+            symbolizeClauseGrainsizeType(enumStr)) {
+      grainsizeMod =
+          ClauseGrainsizeTypeAttr::get(parser.getContext(), *enumValue);
+      if (parser.parseColon())
+        return failure();
+    } else {
+      return parser.emitError(loc, "invalid grainsize modifier : '")
+             << enumStr << "'";
+    }
+  }
+
+  OpAsmParser::UnresolvedOperand operand;
+  if (succeeded(parser.parseOperand(operand))) {
+    grainsize = operand;
+  } else {
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected grainsize operand";
+  }
+
+  if (grainsize.has_value()) {
+    if (parser.parseColonType(grainsizeType))
+      return failure();
+  }
+
+  return success();
+}
+
+static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
+                                 ClauseGrainsizeTypeAttr grainsizeMod,
+                                 Value grainsize, mlir::Type grainsizeType) {
+  if (grainsizeMod)
+    p << stringifyClauseGrainsizeType(grainsizeMod.getValue()) << ": ";
+
+  if (grainsize)
+    p << grainsize << ": " << grainsizeType;
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_tasks Clause
+//===----------------------------------------------------------------------===//
+
+// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
+static ParseResult
+parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
+                    std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
+                    Type &numTasksType) {
+  SMLoc loc = parser.getCurrentLocation();
+  StringRef enumStr;
+
+  if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
+    if (std::optional<ClauseNumTasksType> enumValue =
+            symbolizeClauseNumTasksType(enumStr)) {
+      numTasksMod =
+          ClauseNumTasksTypeAttr::get(parser.getContext(), *enumValue);
+      if (parser.parseColon())
+        return failure();
+    } else {
+      return parser.emitError(loc, "invalid numTasks modifier : '")
+             << enumStr << "'";
+    }
+  }
+
+  OpAsmParser::UnresolvedOperand operand;
+  if (succeeded(parser.parseOperand(operand))) {
+    numTasks = operand;
+  } else {
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected num_tasks operand";
+  }
+
+  if (numTasks.has_value()) {
+    if (parser.parseColonType(numTasksType))
+      return failure();
+  }
+
+  return success();
+}
+
+static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
+                                ClauseNumTasksTypeAttr numTasksMod,
+                                Value numTasks, mlir::Type numTasksType) {
+  if (numTasksMod)
+    p << stringifyClauseNumTasksType(numTasksMod.getValue()) << ": ";
+
+  if (numTasks)
+    p << numTasks << ": " << numTasksType;
+}
+
 //===----------------------------------------------------------------------===//
 // Parsers for operations including clauses that define entry block arguments.
 //===----------------------------------------------------------------------===//
@@ -2593,15 +2695,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
                        const TaskloopOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: privateVars, privateSyms.
-  TaskloopOp::build(
-      builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
-      makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
-      makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
-      clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
-      /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
-      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-      makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
+  TaskloopOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+                    clauses.final, clauses.grainsizeMod, clauses.grainsize,
+                    clauses.ifExpr, clauses.inReductionVars,
+                    makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
+                    makeArrayAttr(ctx, clauses.inReductionSyms),
+                    clauses.mergeable, clauses.nogroup, clauses.numTasksMod,
+                    clauses.numTasks, clauses.priority, /*private_vars=*/{},
+                    /*private_syms=*/nullptr, clauses.reductionMod,
+                    clauses.reductionVars,
+                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+                    makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
 }
 
 SmallVector<Value> TaskloopOp::getAllReductionVars() {
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index d7f468bed3d3d..63ccd7957b492 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2064,6 +2064,30 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
 
 // -----
 
+func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
+  %testi64 = "test.i64"() : () -> (i64)
+  // expected-error @below {{invalid grainsize modifier : 'strict1'}}
+  omp.taskloop grainsize(strict1: %testi64: i64) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+      omp.yield
+    }
+  }
+  return
+}
+// -----
+
+func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
+  %testi64 = "test.i64"() : () -> (i64)
+  // expected-error @below {{invalid numTasks modifier : 'default'}}
+  omp.taskloop num_tasks(default: %testi64: i64) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+      omp.yield
+    }
+  }
+  return
+}
+// -----
+
 func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
   // expected-error @below {{op nested in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
   omp.taskloop {
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e318afbebbf0c..5d44dc1da503d 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2417,6 +2417,22 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
     }
   }
 
+  // CHECK: omp.taskloop grainsize(strict: %{{[^:]+}}: i64) {
+  omp.taskloop grainsize(strict: %testi64: i64) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+
+  // CHECK: omp.taskloop num_tasks(strict: %{{[^:]+}}: i64) {
+  omp.taskloop num_tasks(strict: %testi64: i64) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+      // CHECK: omp.yield
+      omp.yield
+    }
+  }
+
   // CHECK: omp.taskloop nogroup {
   omp.taskloop nogroup {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two parsing functions are doing broadly the same thing. Could they share code?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants