[WIP] Resize scheduler update #3657

Closed · wants to merge 128 commits
Conversation

naoyam (Collaborator) commented Dec 31, 2024

No description provided.

naoyam and others added 27 commits January 15, 2025 14:53
This doesn't seem to break anything, but it's just strange to have iter
domains like `iS1{8 ex 8}`, i.e., a concrete iter domain that appears to
be extended to the same actual extent.
This reverts commit 14193c6.

github-actions bot commented Feb 4, 2025

Review updated until commit 52c19a6

Description

  • Added padding consistency analysis for resize operations.

  • Enhanced pad predicate elimination to omit unnecessary predicates (see the sketch after this list).

  • Improved resize scheduling with detailed logging and graph generation.

  • Added new tests for resize and predicate elimination.
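
To ground the pad-predicate item, here is a minimal fusion of the kind the new analysis targets (an illustrative sketch; makeContigConcreteTensor and pad come from nvFuser's test and ops utilities, and the shape is an assumption):

  // A pad writes its fill value into the padded region under a predicate
  // such as 0 <= i < 16. If every consumer buffer is initialized to the
  // same fill value and its loop domain covers the padded extent, that
  // predicate is redundant and can be omitted.
  Fusion fusion;
  FusionGuard fg(&fusion);
  auto tv0 = makeContigConcreteTensor({16});
  fusion.addInput(tv0);
  // Pad by one element on each side; the default fill value is zero.
  auto tv1 = pad(tv0, {IrBuilder::create<Val>(1L), IrBuilder::create<Val>(1L)});
  fusion.addOutput(tv1);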


Changes walkthrough 📝

Relevant files

Enhancement (8 files)

| File | Description | Lines |
| --- | --- | --- |
| index.cpp | Added padding consistency analysis and predicate elimination for pad operations. | +353/-23 |
| fusion_segmenter.cpp | Added debug prints and improved segment candidate merging logic. | +125/-16 |
| indexing.cpp | Enhanced predicate generation for resize operations. | +45/-2 |
| indexing_traversal.cpp | Added error handling and graph generation for indexing traversal. | +36/-1 |
| registry_utils.cpp | Added a static size check to non-unique broadcast detection. | +24/-2 |
| resize.cpp | Enhanced resize scheduling with detailed logging and graph generation. | +71/-1 |
| loop_domain_scheduler.cpp | Added an id_mapping_mode parameter to scheduleLoopDomainsLike. | +20/-6 |
| test_rope.cpp | Added debug prints and graph generation for rope tests. | +43/-5 |

Configuration changes (4 files)

| File | Description | Lines |
| --- | --- | --- |
| options.cpp | Added a new option for disabling pad predicate elimination. | +1/-0 |
| options.h | Added a new option for disabling pad predicate elimination. | +1/-0 |
| registry_utils.h | Added a static size check parameter to non-unique broadcast detection. | +1/-1 |
| loop_domain_scheduler.h | Added an id_mapping_mode parameter to scheduleLoopDomainsLike. | +2/-1 |

Tests (2 files)

| File | Description | Lines |
| --- | --- | --- |
| test_indexing.cpp | Added new tests for predicate indexing with resize operations. | +288/-0 |
| test_resize.cpp | Added new tests for pad predicate elimination. | +317/-0 |

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Debugging Output

The code includes multiple std::cerr statements for debugging purposes. These should be removed or converted to proper logging before merging.

  // When using the fused reduction, allocate the reduction object at
  // the outer-most scope
  insertAtTopLevel(fused_reduction_alloc_reduction);
}

namespace {

// Check if the tensor is scheduled in such a way that the
// padded region is included in the loop domain. It should be
// sufficient if the mapped resize expr appears between the
// logical and loop domains of this tensor. Note, however, that
// preceding resizes may make this unsafe.
class PaddingConsistencyAnalysis {
 public:
  static bool paddedConsistently(
      TensorView* dep_tv,
      const ExprGroups& pad_resizes) {
    PaddingConsistencyAnalysis analysis(dep_tv, pad_resizes);
    return analysis.is_consistent_;
  }

 private:
  PaddingConsistencyAnalysis(TensorView* dep_tv, const ExprGroups& pad_resizes)
      : pad_resizes_(pad_resizes) {
    const auto& tensor_indexer = GpuLower::current()->tensorIndexer();
    const auto& index_traversal_graph = tensor_indexer.traversalGraph();

    auto dep_tv_def = dep_tv->definition();

    const auto logical_domain_indexing_path =
        tensor_indexer.getIndexingPath(dep_tv_def, dep_tv->getLogicalDomain());

    for (const auto& [path_expr_g, dir] : logical_domain_indexing_path) {
      // Since the indexing traversal may use a local graph for resize,
      // we need to explicitly find the group in the indexing
      // traversal graph.
      const auto& expr_g = index_traversal_graph.toGroup(path_expr_g->front());
      const auto inputs = getInputsOfExpr(
          expr_g,
          dir,
          ValGraphInputs(index_traversal_graph),
          ValGraphOutputs(index_traversal_graph));

      const auto outputs = getOutputsOfExpr(
          expr_g,
          dir,
          ValGraphInputs(index_traversal_graph),
          ValGraphOutputs(index_traversal_graph));

      if (expr_g->front()->isA<Resize>()) {
        if (!handleResize(expr_g, inputs, outputs)) {
          return;
        }
      } else {
        if (!handleNonResize(expr_g, inputs, outputs)) {
          return;
        }
      }
    }

    // Traversed the path successfully, which should mean there's no
    // conflicting resize

    // Need to make sure all of the pad resizes are found in the
    // path
    if (detected_resize_exprs_ != pad_resizes.set()) {
      return;
    }

    is_consistent_ = true;
  }

  bool handleResize(
      const ExprGroup& expr_g,
      const ValGroups& inputs,
      const ValGroups& outputs) {
    const auto& resize_input = inputs.at(0);
    if (no_more_resize_dep_set_.count(resize_input)) {
      // No resize allowed
      return false;
    }

    // Check if this is a resize of the pad op.
    auto resize_expr_it =
        std::find(pad_resizes_.begin(), pad_resizes_.end(), expr_g);
    if (resize_expr_it != pad_resizes_.end()) {
      detected_resize_exprs_.insert(expr_g);
      for (const auto& output_g : outputs) {
        NVF_ERROR(pad_resize_dep_map_.emplace(output_g, expr_g).second);
      }
      return true;
    }

    const bool depends_on_pad_resize = std::any_of(
        inputs.begin(), inputs.end(), [&](const ValGroup& input_group) {
          return pad_resize_dep_map_.count(input_group);
        });

    if (!depends_on_pad_resize) {
      return true;
    }

    // There's a dependency. Check if this is a valid expr. If it's
    // another resize, it must have non-negative expansion factors.
    // If it's not a resize, no further resize will be allowed.

    // Different resize. For padded sides, the resize expand factor
    // must be non-negative.
    const ExprGroup& original_pad_resize = pad_resize_dep_map_.at(resize_input);

    // Left expand factor
    if (!original_pad_resize->front()->as<Resize>()->leftExpand()->isZero()) {
      auto dep_resize_left = expr_g->front()->as<Resize>()->leftExpand();
      if (!dep_resize_left->isConstInt() ||
          dep_resize_left->evaluate().as<int64_t>() < 0) {
        return false;
      }
    }

    if (!original_pad_resize->front()->as<Resize>()->rightExpand()->isZero()) {
      auto dep_resize_right = expr_g->front()->as<Resize>()->rightExpand();
      if (!dep_resize_right->isConstInt() ||
          dep_resize_right->evaluate().as<int64_t>() < 0) {
        return false;
      }
    }

    for (const auto& output_g : outputs) {
      NVF_ERROR(
          pad_resize_dep_map_.emplace(output_g, original_pad_resize).second);
    }

    return true;
  }

  bool handleNonResize(
      const ExprGroup& expr_g,
      const ValGroups& inputs,
      const ValGroups& outputs) {
    const bool depends_on_pad_resize = std::any_of(
        inputs.begin(), inputs.end(), [&](const ValGroup& input_group) {
          return pad_resize_dep_map_.count(input_group);
        });

    // If depends on a pad resize, no further resize is
    // allowed. Propagate the no-resize info
    if (depends_on_pad_resize ||
        std::any_of(
            inputs.begin(), inputs.end(), [&](const ValGroup& input_group) {
              return no_more_resize_dep_set_.count(input_group);
            })) {
      for (const auto& output_group : outputs) {
        no_more_resize_dep_set_.emplace(output_group);
      }
    }

    return true;
  }

 private:
  const ExprGroups& pad_resizes_;

  bool is_consistent_ = false;

  std::unordered_map<ValGroup, ExprGroup> pad_resize_dep_map_;
  std::unordered_set<ValGroup> no_more_resize_dep_set_;
  std::unordered_set<ExprGroup> detected_resize_exprs_;
};

bool canOmitPadPredicate(const PadOp* pad) {
  if (isOptionDisabled(DisableOption::PadPredicateElimination)) {
    return false;
  }

  // TensorIndexer and PredicateElimination are required
  if (!GpuLower::current()->isTensorIndexerEnabled() ||
      isOptionDisabled(DisableOption::PredicateElimination)) {
    return false;
  }

  std::cerr << "Checking " << pad->toString();

  auto consumer_tv = pad->out()->as<TensorView>();

  const auto& tensor_indexer = GpuLower::current()->tensorIndexer();
  const auto& index_traversal_graph = tensor_indexer.traversalGraph();
  const auto& pred_info = GpuLower::current()->predicateElimination();

  auto resize_exprs = DependencyCheck::getAllExprsBetween(
      {consumer_tv->getRootDomain().begin(),
       consumer_tv->getRootDomain().end()},
      {consumer_tv->getLogicalDomain().begin(),
       consumer_tv->getLogicalDomain().end()});

  NVF_ERROR(
      !resize_exprs.empty() &&
      std::all_of(resize_exprs.begin(), resize_exprs.end(), [](Expr* expr) {
        return expr->isA<Resize>();
      }));

  // All resize expansion factors must be static and non-negative
  for (auto expr : resize_exprs) {
    for (auto expand_val :
         {expr->as<Resize>()->leftExpand(),
          expr->as<Resize>()->rightExpand()}) {
      if (!expand_val->isConstInt()) {
        return false;
      }
      auto expand_int = expand_val->evaluate().as<int64_t>();
      if (expand_int < 0) {
        return false;
      }
    }
  }
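
  // Illustrative examples (hypothetical, not from this diff): a resize
  // introduced by pad(tv, {2, 3}) has leftExpand() == 2 and
  // rightExpand() == 3, both static and non-negative, so it passes the
  // check above. A symbolic pad width such as pad(tv, {i0, 0}), where i0
  // is an input scalar, fails isConstInt() and the pad predicate is kept.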

  ExprGroups resize_expr_groups = index_traversal_graph.toGroups(resize_exprs);

  const auto pad_val = pad->value();

  // Contains tensors to check. Starts with the consumer tv; its
  // upstream dependent tensors are added as necessary. Specifically,
  // when an expr is predicated, check if it's initialized to the same
  // value as the padding value. If so, and the consumer of the expr
  // is found to have the padded region as part of its loop domain, it
  // should be safe to omit the pad predicate. If the predicate of the
  // expr is omitted, its preceding exprs need to be checked. Since
  // reading fusion inputs should never omit predicates, this list
  // should never include fusion inputs.

  std::deque<TensorView*> tvs_to_check;
  tvs_to_check.push_back(consumer_tv);

  while (!tvs_to_check.empty()) {
    auto tv_to_check = tvs_to_check.front();
    tvs_to_check.pop_front();

    // tvs_to_check should never include a fusion input
    NVF_ERROR(
        !tv_to_check->isFusionInput(),
        "Not expected to have a fusion input: ",
        tv_to_check->toString());

    auto tv_expr = tv_to_check->definition();
    NVF_ERROR(
        tv_expr != nullptr,
        "Unexpected to have no definition: ",
        tv_to_check->toString());

    if (pred_info.canOmitPredicate(tv_expr)) {
      // If predicate is omitted and producer values are just
      // propagated, check the producers

      // Check if the producer value is just moved
      if (tv_expr != pad) {
        if (!tv_expr->isOneOf<
                LoadStoreOp,
                BroadcastOp,
                ExpandOp,
                SqueezeOp,
                SliceOp,
                CatOp,
                ViewOp,
                UnaryOp>()) {
          std::cerr << "Unsupported op: " << tv_expr->toString();
          return false;
        }

        // For unary ops, only cast is allowed for now. It should be
        // possible to support, e.g., abs, neg, etc. Neg must be handled
        // carefully as negative zero differs from positive zero, which
        // matters for bitwise-or-based concat.
        if (auto uop = dynamic_cast<UnaryOp*>(tv_expr)) {
          if (uop->getUnaryOpType() != UnaryOpType::Cast) {
            std::cerr << "Unsupported op: " << tv_expr->toString();
            return false;
          }
        }
      }

      // If there's no producer, i.e., a full op, and the predicate of
      // the expr is omitted, we can't guarantee anything about the
      // padded region.
      auto producer_tvs = ir_utils::producerTvsOf(tv_to_check);
      if (producer_tvs.empty()) {
        return false;
      }
      for (auto producer_tv : producer_tvs) {
        // If tv_expr has a fusion input as one of its inputs, its
        // predicate should never be omitted, so producer_tv should
        // not be a fusion input
        NVF_ERROR(!producer_tv->isFusionInput());
        tvs_to_check.push_back(producer_tv);
      }
    } else {
      auto init_val = pred_info.getInitValue(tv_to_check);
      if (init_val == nullptr) {
        // Can't determine if it's safe to omit without an init value
        std::cerr << "No init val\n";
        return false;
      }

      // Note Val::sameAs may not work as the data types may be
      // different (e.g., 0.0f vs 0L). The init_val nullptr case was
      // already handled above.
      bool initialized_to_same_value = pad_val->value().hasValue() &&
          init_val->value().hasValue() &&
          pad_val->value() == init_val->value();

      if (!initialized_to_same_value) {
        std::cerr << "Not initialized to same val\n";
        return false;
      }

      if (!PaddingConsistencyAnalysis::paddedConsistently(
              tv_to_check, resize_expr_groups)) {
        std::cerr << "Not consistently padded\n";
        return false;
      }
    }
  }

  std::cerr << "Pad predicate can be safely removed: " << pad->toString();
Debugging Output

The code includes multiple std::cerr statements for debugging purposes. These should be removed or converted to proper logging before merging.

if (std::all_of(
Debugging Output

The code includes multiple std::cout and std::cerr statements for debugging purposes. These should be removed or converted to proper logging before merging.

  std::cout << std::endl;
  std::cout << "Resize scheduling\n";
  fusion->print();
  std::cout << std::endl;
}

{
  std::stringstream file_name;
  file_name << "pre_scheduling.dot";
  IrGraphGenerator::print(
      fusion,
      file_name.str().c_str(),
      IrGraphGenerator::DetailLevel::ComputeOnly);
}
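
// The emitted pre_scheduling.dot can be rendered with Graphviz for
// inspection, e.g.: dot -Tpng pre_scheduling.dot -o pre_scheduling.png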

auto ref_tv = getReferenceTensor(fusion);
NVF_ERROR(ref_tv != nullptr);

scheduler_utils::cacheInputs(fusion, true);
scheduler_utils::cacheAndForkOutputs(fusion, true);

auto resize_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);

std::unique_ptr<IdModel> id_model =
    std::make_unique<IdModel>(fusion, /*build_graphs=*/false);
id_model->buildExactGraph();

// Replicate resize inputs if necessary to avoid conflicting
// propagations
const auto exclusivity_info_map = scheduler_tools::getNonExclusiveResizeInfo(
    resize_tensor_ops, id_model->idGraph(IdMappingMode::EXACT));
for (auto resize_tensor_op : resize_tensor_ops) {
  auto out_tv = resize_tensor_op->output(0)->as<TensorView>();
  if (exclusivity_info_map.count(out_tv) == 0) {
    continue;
  }
  auto inp_tv = resize_tensor_op->input(0)->as<TensorView>();
  // Since cacheInput may skip caching if an input is used by
  // slice/pad, inp_tv may be a fusion input, in which case it is
  // not necessary to recompute the tensor.
  if (inp_tv->isFusionInput()) {
    continue;
  }
  auto inp_tv_copy = RecomputeTv::recompute(inp_tv);
  ir_utils::replaceValInExprInputs(resize_tensor_op, inp_tv, inp_tv_copy);
}

{
  std::cout << std::endl;
  std::cout << "After recomputation\n";
  fusion->print();
  std::cout << std::endl;

  std::stringstream file_name;
  file_name << "after_recomputation.dot";
  IrGraphGenerator::print(
      fusion,
      file_name.str().c_str(),
      IrGraphGenerator::DetailLevel::ComputeOnly);
}

TensorView* largest_input = nullptr;
if (resize_params->largest_input >= 0) {
  largest_input =
      fusion->inputs().at(resize_params->largest_input)->as<TensorView>();

  // The tensors are going to be reordered to align with the largest
  // input. For this to work, merge operations for reshape need to be
  // cancelled.
  scheduler_tools::cancelReshapeInLoopDomains(largest_input);
}

{
  std::cout << std::endl;
  std::cout << "After reshape cancel\n";
  fusion->print();
  std::cout << std::endl;
}
for (auto expr : fusion->exprs()) {
  if (!expr->isOneOf<SliceOp, PadOp>()) {
    continue;
  }

  std::cerr << "propagateResize: " << expr->toString();

  scheduler_tools::propagateResizeToInputs(expr);
}

// Update the IdModel
id_model = std::make_unique<IdModel>(fusion, /*build_graphs=*/false);
id_model->buildExactGraph();

// Detect an ending repeat
auto static_repeat_info = scheduler_tools::getMaybeStaticRepeatInfo(ref_tv);

if (static_repeat_info.has_value()) {
  std::cerr << "Static repeat: "
            << static_repeat_info->reshape_repeat_id->toString() << "\n";
}

// Just simple scheduling for now.
// TODO: Do something smarter. Can just use the pointwise scheduler?

std::cerr << "Ref tensor: " << ref_tv->toString() << "\n";

// Reorder tensors to align with the largest input. This is expected
// to improve memory read performance, while write performance could
// be lowered. Optimizing read performance should generally matter
// more, but a more robust decision process would be needed.
if (largest_input != nullptr) {
  std::vector<IterDomain*> ref_alloc;
  ref_alloc.reserve(largest_input->getMaybeAllocationDomain().size());
  std::copy_if(
      largest_input->getMaybeAllocationDomain().begin(),
      largest_input->getMaybeAllocationDomain().end(),
      std::back_inserter(ref_alloc),
      [](IterDomain* alloc_id) {
        return !alloc_id->isBroadcast() && !alloc_id->isReduction() &&
            !alloc_id->isDeviceDim();
      });

  // Reorder the reference to match the allocation domain of the
  // largest fusion input
  scheduler_utils::reorderTensorLike(ref_tv, ref_alloc);
}

const int64_t bdimx = 128;

// Make sure the DID ID is located at the outermost position
auto outermost_pos = scheduler_utils::reorderDevicesToOuter(ref_tv);

// [DID, ..., ...]
//        ^
//        +--- outermost_pos

// Move the static repeat ID to the outermost position if
// detected. The repeat ID then just remains there with no
// scheduling.
bool repeat_id_moved_to_outermost = false;
if (static_repeat_info.has_value()) {
  NVF_ERROR(ref_tv == static_repeat_info->repeat_output_tv);
  auto ref_repeat_id_it = std::find_if(
      ref_tv->getLoopDomain().begin(),
      ref_tv->getLoopDomain().end(),
      [&](IterDomain* loop_id) {
        return id_model->idGraph(IdMappingMode::EXACT)
            .disjointValSets()
            .strictAreMapped(loop_id, static_repeat_info->reshape_repeat_id);
      });
  // Give up if the repeat ID is not found. It's unclear whether this
  // could actually happen, though.
  if (ref_repeat_id_it != ref_tv->getLoopDomain().end()) {
    auto repeat_id_pos =
        std::distance(ref_tv->getLoopDomain().begin(), ref_repeat_id_it);
    NVF_ERROR(
        repeat_id_pos >= outermost_pos,
        "Unexpected to have DID-parallelized repeat axis: ",
        static_repeat_info->reshape_repeat_id->toString());

    // [DID, ..., repeat_id, ...]
    //        ^
    //        +--- outermost_pos
    ref_tv->reorder(std::unordered_map<int64_t, int64_t>{{repeat_id_pos, 0}});
    ++outermost_pos;
    // [repeat_id, DID, ...]
    //                   ^
    //                   +--- outermost_pos

    repeat_id_moved_to_outermost = true;
  }
}

const int64_t vec_factor = resize_params->vectorization_factor;

int64_t next_innermost_pos = -1;
// [..., ...]
//        ^
//        +--- next_innermost_pos

if (vec_factor > 1) {
  ref_tv->split(-1, vec_factor);
  --next_innermost_pos;
  // [..., vec_factor]
  //   ^
  //   +--- next_innermost_pos
}

ref_tv->flatten(outermost_pos, next_innermost_pos);
// [..., I0, vec_factor]
//       ^
//       +--- next_innermost_pos

ref_tv->split(next_innermost_pos, bdimx);
ref_tv->axis(next_innermost_pos)->parallelize(ParallelType::TIDx);
--next_innermost_pos;
// [..., I0/bdimx, bdimx(TIDx), vec_factor]
//         ^
//         +--- next_innermost_pos

if (resize_params->split_grid_x_dim) {
  ref_tv->split(next_innermost_pos, ResizeParams::max_gdimx);
  // [..., I0/bdimx/max_gdimx, max_gdimx, bdimx(TIDx), vec_factor]
}
ref_tv->axis(next_innermost_pos)->parallelize(ParallelType::BIDx);
// [..., I0/bdimx/max_gdimx, max_gdimx(BIDx), bdimx(TIDx), vec_factor] or
// [..., I0/bdimx(BIDx), bdimx(TIDx), vec_factor]
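
// Illustrative mapping (hypothetical values: vec_factor == 4, bdimx == 128,
// no grid split): the schedule above decomposes a flattened extent N as
//   i = (blockIdx.x * 128 + threadIdx.x) * 4 + v,  with v in [0, 4),
// i.e., each thread performs one contiguous 4-element vectorized access.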

std::cout << "Before ref prop\n";
fusion->print();
std::cout << std::endl;

for (auto tv : fusion->allTvs()) {
  std::cerr << tv->toString() << "\n";
  for (auto expr : tv->domain()->allExprs()) {
    std::cerr << expr->toString();
  }
  std::cerr << "---\n";
}

{
  IdModel idg(fusion, false);
  idg.buildExactGraph();
  std::ofstream ofs("exact_graph_before_ref_prop.dot", std::ofstream::trunc);
  auto dot_string = idg.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph();
  ofs << dot_string;
  ofs.close();
}

// Propagate the reference to the other tensors. Note that the
// update flag is enabled to work around the resize propagation
// issue. This may not work if there's a tensor that is reshaped
// from the reference tensor, but that should not be the case as the
// reference is picked by the same routine used for the pointwise
// scheduler.
//
// When an ending static repeat is detected and the repeat ID is
// moved to the outermost position, propagation is done separately
// between the tensors before the repeat and after the repeat. The
// tensors are first grouped into the pre-repeat group and the
// post-repeat group, where only the latter group has the repeat
// IDs. When propagating the loop domain of the reference tensor,
// which has the repeat ID, the full loop domain is propagated only
// to the post-repeat group. For the pre-repeat group, the repeat ID
// is dropped and only the remaining loop domain is propagated.
if (repeat_id_moved_to_outermost) {
  // Divide all tvs into the pre- and post-repeat groups
  auto all_tvs = fusion->allTvs();
  std::vector<TensorView*> post_repeat_tvs;
  post_repeat_tvs.reserve(static_repeat_info->repeat_tvs.size());
  std::vector<TensorView*> pre_repeat_tvs;
  pre_repeat_tvs.reserve(
      all_tvs.size() - static_repeat_info->repeat_tvs.size());
  for (auto tv : all_tvs) {
    if (static_repeat_info->repeat_tvs.count(tv)) {
      post_repeat_tvs.push_back(tv);
    } else {
      pre_repeat_tvs.push_back(tv);
    }
  }

  // The repeat ID should be located at the outermost position
  std::vector<IterDomain*> non_repeated_loop{
      ref_tv->getLoopDomain().begin() + 1, ref_tv->getLoopDomain().end()};

  scheduler_tools::scheduleLoopDomainsLike(
      pre_repeat_tvs,
      non_repeated_loop,
      /*update_loop_domain_only=*/true);
  scheduler_tools::scheduleLoopDomainsLike(
      post_repeat_tvs,
      ref_tv->getLoopDomain(),
      /*update_loop_domain_only=*/true);
} else {
  scheduler_tools::scheduleLoopDomainsLike(
      fusion->allTvs(),
      ref_tv->getLoopDomain(),
      /*update_loop_domain_only=*/true);
}

naoyam closed this Mar 14, 2025