Add utility lower_utils::proveLinearAndGetStride #2911

Merged
merged 61 commits into from
Sep 23, 2024

Conversation

zasdfgbnm
Collaborator

@zasdfgbnm zasdfgbnm commented Sep 5, 2024

Add utility lower_utils::proveLinearAndGetStride, which proves that the provided linear_g is linear with respect to the given domain, and returns the linear coefficient (stride). This utility will be used in the lowering of mma to fill in the "Leading dimension byte offset" and "Stride dimension byte offset" in the matrix descriptor.

See:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#strides
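
As a rough intuition for what "linear with stride" means, here is a toy integer sketch. It is not the nvFuser implementation, which works symbolically on ValGroups; `StrideResult`, `linearStrideOf`, and `mergedOffsets` are hypothetical names, and a row-major allocation of extents [2, 4, 8] is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type: whether the offsets are linear, and the stride.
struct StrideResult {
  bool is_linear;
  int64_t stride;
};

// offsets[i] is the allocation offset reached when the group index is i
// (all other indices held at 0). The group is linear if
// offsets[i] == offsets[0] + stride * i for every i.
StrideResult linearStrideOf(const std::vector<int64_t>& offsets) {
  if (offsets.size() < 2) {
    return {true, 0};
  }
  const int64_t stride = offsets[1] - offsets[0];
  for (std::size_t i = 2; i < offsets.size(); ++i) {
    if (offsets[i] - offsets[0] != stride * static_cast<int64_t>(i)) {
      return {false, 0};
    }
  }
  return {true, stride};
}

// Offsets visited by merge(outer, inner), where each input has a known
// stride in the assumed row-major [2, 4, 8] allocation (strides 32, 8, 1).
std::vector<int64_t> mergedOffsets(
    int64_t outer_extent,
    int64_t outer_stride,
    int64_t inner_extent,
    int64_t inner_stride) {
  std::vector<int64_t> out;
  for (int64_t o = 0; o < outer_extent; ++o) {
    for (int64_t i = 0; i < inner_extent; ++i) {
      out.push_back(o * outer_stride + i * inner_stride);
    }
  }
  return out;
}
```

In this model, `linearStrideOf(mergedOffsets(4, 8, 8, 1))` reports a linear group with stride 1, because the extent-4 and extent-8 dimensions are adjacent. Merging the extent-2 and extent-8 dimensions while skipping the extent-4 one, `mergedOffsets(2, 32, 8, 1)`, is not linear.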

@zasdfgbnm zasdfgbnm changed the title proveLinearAndGetStride Add utility lower_utils::proveLinearAndGetStride Sep 7, 2024
Comment on lines +597 to +600
template <typename T>
using MaybeUniqueOwningPtr = dynamic_type::
DynamicType<dynamic_type::NoContainers, T*, std::unique_ptr<T>>;

Collaborator Author

I ended up not using this, but it seems useful to leave it there.

Comment on lines +186 to +193
static ExprPath getExprsBetween(
const ValGraph& graph,
std::vector<NodeType> from,
std::vector<NodeType> to) {
ValGraphBFS bfs(graph, std::move(from), std::move(to));
bfs.traverse();
return bfs.getShortestExprPath();
}
Collaborator Author

I ended up not using this, but I feel that it is helpful to keep.

@zasdfgbnm
Collaborator Author

!build

@zasdfgbnm zasdfgbnm marked this pull request as ready for review September 7, 2024 07:54
@zasdfgbnm
Collaborator Author

@naoyam This PR is ready for review again. Highlighted changes:

  • Rename LinearGroupProjection -> Projection
  • Added a paragraph in the comment suggesting that we view Projection as a formal language, where the different types in the dynamic type represent different node types in the abstract syntax tree (AST). Updated the other comments that mention "abstract syntax tree". I believe this viewpoint is helpful for understanding the algorithm.
  • Added a new function simplify that traverses the AST, identifies specific patterns, and simplifies them. The function simplify works much like a compiler, where each pass finds and replaces a specific pattern in the AST.
  • propagate no longer does any simplification. It focuses on generating the correct AST and invokes simplify to simplify it.
  • Rename PartOf::group -> PartOf::what.

I believe that after this change, the code is much more modular and easier to understand.

// inner_extent=3,
// selected_extent=2}
// =>
// PartOf{what=24, inner_extent=6, selected_extent=2}
Collaborator

To add a little more context here, PartOf{what=24, inner_extent=2, selected_extent=12} can be generated by merging linear_g of extent 12 and a domain of extent 2, i.e., merge(linear_g, [2]) -> PartOf{what=24, inner_extent=2, selected_extent=12}.

Then, what propagation would create this?

   PartOf{
     what=PartOf{what=24, inner_extent=2, selected_extent=12},
     inner_extent=3,
     selected_extent=2}

Collaborator Author

[Image attachment: 1000001421]

Collaborator

Hmm, I'm confused. PartOf{what=24, inner_extent=2, selected_extent=12} represents the 24 node. Then, IIUC, PartOf{ what=PartOf{what=24, inner_extent=2, selected_extent=12},... should be generated by propagating through either merge or split, but your diagram indicates something propagated along the 12 node. What am I missing?

Collaborator Author

When we propagate from the red 2 to the 24, we first propagate through the merge that generates the 6. At this point, we get:

PartOf{what=6, inner_extent=3, selected_extent=2}

The next step is to propagate through the merge that generates the 12, and we will replace the 6 with

PartOf{what=12, inner_extent=nullptr, selected_extent=6}

which gives:

PartOf{what=PartOf{what=12, inner_extent=nullptr, selected_extent=6}, inner_extent=3, selected_extent=2}

which will be simplified to

PartOf{what=12, inner_extent=3, selected_extent=2}

Then we will propagate through the merge that generates the 24, which will replace the 12 with

PartOf{what=24, inner_extent=2, selected_extent=12}

and get

PartOf{what=PartOf{what=24, inner_extent=2, selected_extent=12}, inner_extent=3, selected_extent=2}
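
The nested-PartOf simplification in this walkthrough can be checked with a toy integer model. This is a sketch under assumptions, not the nvFuser code: it assumes PartOf{what, inner_extent, selected_extent} selects the index `(i / inner_extent) % selected_extent` of an index `i` into `what`, it uses 1 where the thread writes nullptr, and `flatten` and `flattenPreservesIndex` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Toy integer stand-in for PartOf (the real type holds symbolic values).
struct PartOf {
  int64_t what_extent;
  int64_t inner_extent; // 1 plays the role of nullptr in the thread
  int64_t selected_extent;

  // Assumed semantics: the selected index for an index i into `what`.
  int64_t index(int64_t i) const {
    return (i / inner_extent) % selected_extent;
  }
};

// Flatten PartOf{what=nested_what, inner_extent, selected_extent}
// into a single PartOf over the innermost `what`.
PartOf flatten(
    const PartOf& nested_what,
    int64_t inner_extent,
    int64_t selected_extent) {
  // The outer selection must fit evenly inside the nested selection.
  assert(nested_what.selected_extent % (inner_extent * selected_extent) == 0);
  return PartOf{
      nested_what.what_extent,
      nested_what.inner_extent * inner_extent,
      selected_extent};
}

// Brute-force check that flattening does not change the selected index.
bool flattenPreservesIndex(
    const PartOf& nested_what,
    int64_t inner_extent,
    int64_t selected_extent) {
  const PartOf flat = flatten(nested_what, inner_extent, selected_extent);
  for (int64_t i = 0; i < nested_what.what_extent; ++i) {
    const int64_t nested =
        (nested_what.index(i) / inner_extent) % selected_extent;
    if (nested != flat.index(i)) {
      return false;
    }
  }
  return true;
}
```

With `PartOf{24, 2, 12}` as the nested `what`, `flatten(..., 3, 2)` yields inner_extent 6 and selected_extent 2, matching the `PartOf{what=24, inner_extent=6, selected_extent=2}` shown in the code comment under review.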

Collaborator

Thanks a lot. This makes sense to me. Can you add this as a code comment as well? I think it's very helpful to understand the logic.

Collaborator Author

Added to the comment before propagate. Also, I moved the definition of simplify below proveLinearAndGetStrideAfterPropagation so that the comments appear in an order that is easier for the reader to follow.

Collaborator Author

@zasdfgbnm zasdfgbnm Sep 19, 2024

I also added a "Proof of correctness" to the comment of cancelCommonFactors. The mathematical foundation for this simplification is Theorem 2.1 in https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md.

// canceling the last items in `what`.
//
// Example:
// PartOf{what=[5,3,2], inner_extent=42} => PartOf{what=[5], inner_extent=7}
Collaborator

Could you also mention how PartOf{what=[5,3,2], inner_extent=42} can be generated?

Collaborator Author

[Image attachment: Screenshot_20240918_093052_Samsung Notes]

Collaborator Author

This is not a very good example. A better example could be

PartOf{what=[7,3,2], inner_extent=30}

Collaborator

Do you mean a domain of extent 30 is split twice? I think I understand each split would generate [1, 42], inner_extent=42 and [5, 3, 2], but I don't follow why they could be combined.

Collaborator Author

Yes, it is split twice. The two results can be combined because, for example, [1, 42] is the loop domain and [5, 3, 2] is the allocation domain.

Collaborator Author

Yes, right. I think I finally understand your question after you answered it yourself 😂.

Collaborator

So, about this simplification, I suppose this is the same simplification we discussed earlier ([5, 3] -> [5]). In the case of [5, 3, 2], inner_extent=42 -> [5], inner_extent=7, this means that we view the inner two IDs of [3, 2] as a non-modifiable unit, so it should be effectively equivalent to [5], inner_extent=7. However, what would happen if there's also an expr to traverse that uses [3, 2]? For example, let's say there's a split of 2 to [1, 3]. Since what no longer has [2], this propagation would be just a no-op? Or, what about, e.g., a merge of 3 as the outer and 5 as the inner?

Collaborator Author

@zasdfgbnm zasdfgbnm Sep 19, 2024

This is a really good question. I love it because it really touches the mathematical foundation of the proof we implemented.

When we start from linear_g and propagate to 5, 3, 2, mathematically, we are proving that linear_g ~ f(5, 3, 2), where f is a specific way of operating on its parameters, and ~ denotes "equivalent" (e.g. 6 ~ [3, 2] if you have 6 = merge(3, 2)).

When we run the simplification, we simplify f(5, 3, 2) into g(5), where g is another specific way of operating on its parameters. Simplifying f(5, 3, 2) into g(5) is just like simplifying $f(x, y, z) = z + x + 0 \cdot y - z$ into $g(x) = x$ in a math class: y and z being canceled means that the value of f(x, y, z) does not depend on the values of y and z. Similarly, in this case, simplifying f(5, 3, 2) into g(5) means that the index of f(5, 3, 2) does not depend on the indices of 3 and 2.

When we further propagate from g(5) to domain, mathematically, we are proving that g(5) ~ h(x1, x2, ...), where x1, x2, ... are ValGroups in domain. Note that this process only depends on the exprs between 5 and domain, nothing else. The split of 2 into [1, 3] is not on this path, so it is unrelated. Merging 3 with 5 is related, but the only relevant information here is "3 is not a parameter of g". It does not matter where this 3 comes from or whether it is connected to linear_g, because here we are focusing only on the subproblem of proving g(5) ~ h(x1, x2, ...), and linear_g is unrelated to this subproblem. This is just like, in a math class (assuming $0 \le x < 5$), if I need to prove $g(x) = h(z)$, where $g(x) = x$, $h(z) = z \% 5$, and $z = x + 5 \cdot y$, the only thing I need to pay attention to is $(x + 5 \cdot y) \% 5 = x$ for an arbitrary $y$. Where this y comes from, and whether g(x) was simplified from an f(x, y, z) that contains y, is unrelated.
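
The identity used in the math-class analogy can be spelled out as follows, under the added assumption that $y$ is a non-negative integer. Since $0 \le x < 5$, the expression $x + 5 \cdot y$ is already written in quotient-remainder form with divisor 5, and Euclidean division is unique, so:

$$
x + 5 \cdot y = 5 \cdot y + x, \quad 0 \le x < 5
\;\Longrightarrow\;
(x + 5 \cdot y) \% 5 = x
\quad \text{and} \quad
\left\lfloor \frac{x + 5 \cdot y}{5} \right\rfloor = y .
$$

Neither result depends on where $y$ came from, which is the point made in the comment above.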

Collaborator

So, IIUC, what you're saying is that no matter what other exprs are traversed, if it uses anything other than 5, it should be just ignored as it shouldn't matter. Is my understanding correct?

Collaborator

I think what I missed was this part:

inner_extent is a multiple of the extent of the last items in what,

So, because of this condition, the canceled factor should not matter for the remaining traversal.

It finally makes sense to me. Thanks for the updated comment. That helped a lot.
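
The cancellation discussed in this thread can be sketched with plain integers. This is an illustration under assumptions rather than the actual cancelCommonFactors: the real code works on symbolic extents, `Cancelled`, `cancelTrailingFactors`, and `betterExampleHolds` are hypothetical names, all extents are assumed positive, and factors are cancelled greedily one trailing item at a time.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical integer stand-in for PartOf{what=[...], inner_extent=...}.
struct Cancelled {
  std::vector<int64_t> what;
  int64_t inner_extent;
};

// Cancel trailing items of `what` while inner_extent is a multiple of
// the last item's extent. Assumes every extent is positive.
Cancelled cancelTrailingFactors(
    std::vector<int64_t> what,
    int64_t inner_extent) {
  while (!what.empty() && inner_extent % what.back() == 0) {
    inner_extent /= what.back();
    what.pop_back();
  }
  return {std::move(what), inner_extent};
}

// Brute-force check of the "better example" from the thread:
// PartOf{what=[7,3,2], inner_extent=30} vs PartOf{what=[7], inner_extent=5},
// i.e. (6*i7 + 2*i3 + i2) / 30 == i7 / 5 for every valid index.
bool betterExampleHolds() {
  for (int64_t i7 = 0; i7 < 7; ++i7) {
    for (int64_t i3 = 0; i3 < 3; ++i3) {
      for (int64_t i2 = 0; i2 < 2; ++i2) {
        if ((6 * i7 + 2 * i3 + i2) / 30 != i7 / 5) {
          return false;
        }
      }
    }
  }
  return true;
}
```

In this model, `cancelTrailingFactors({5, 3, 2}, 42)` leaves `what=[5]` with inner_extent 7, and `cancelTrailingFactors({7, 3, 2}, 30)` leaves `what=[7]` with inner_extent 5. The indices of the cancelled items never appear in the result, which is why later exprs that touch only those items do not matter.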

count++;
const auto& item = dq.at(dq.size() - count);
what_extent = SimplifyingIrBuilder::mulExpr(what_extent, extent(item));
if (simplifyExpr(IrBuilder::isDivisibleExpr(what_extent, required_extent))
Collaborator

Can this be just what_extent >= required_extent?

Collaborator Author

I believe not, for a similar reason as in #415, where we changed >= to gcd in the vectorization helper.

Collaborator

Thanks. Can you please add this as a comment?

Collaborator Author

@zasdfgbnm zasdfgbnm Sep 19, 2024

Added a "Proof of correctness" for trimRedundant. The mathematical foundation for this simplification is Theorem 2.1 in https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md.
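
One way to see why divisibility is needed rather than `>=` is a brute-force check on a toy integer model. This sketch is an assumption-laden illustration and not the trimRedundant implementation: `canDropOuter` is a hypothetical name, and it assumes the quantity of interest is the flattened index of `what=[outer, mid, last]` taken modulo `required_extent`.

```cpp
#include <cassert>
#include <cstdint>

// Returns true if dropping the outermost item of what=[outer, mid, last]
// never changes (flattened index) % required_extent.
bool canDropOuter(
    int64_t outer,
    int64_t mid,
    int64_t last,
    int64_t required_extent) {
  const int64_t kept_extent = mid * last; // extent of the items we keep
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t m = 0; m < mid; ++m) {
      for (int64_t l = 0; l < last; ++l) {
        const int64_t full = (o * kept_extent + m * last + l) % required_extent;
        const int64_t trimmed = (m * last + l) % required_extent;
        if (full != trimmed) {
          return false;
        }
      }
    }
  }
  return true;
}
```

With `required_extent = 4`, keeping `[2, 3]` gives a kept extent of 6, which satisfies `6 >= 4` but not divisibility, and `canDropOuter(5, 2, 3, 4)` is false. Keeping `[2, 4]` gives a kept extent of 8, which is a multiple of 4, and `canDropOuter(5, 2, 4, 4)` is true.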

// = 512.
namespace {

// From the above example, we can see that how linear_g lives in domain could be
Collaborator

General question: how is reordering handled? Since there's no "reorder" op, the ValGraph traversal itself doesn't seem to account for effects by reordering.

Collaborator Author

Reordering does not need to be handled, just as we do not need to handle it in the vectorization helper. However, we do check the order when we see a merge and when we run proveLinearAndGetStrideAfterPropagation (basically, any function that uses search checks the order).

Collaborator

@naoyam naoyam left a comment

LGTM. Thanks for all the discussions!

@zasdfgbnm zasdfgbnm merged commit cd1e4f5 into main Sep 23, 2024
5 checks passed
@zasdfgbnm zasdfgbnm deleted the linear branch September 23, 2024 22:24