Add utility lower_utils::proveLinearAndGetStride
#2911
template <typename T>
using MaybeUniqueOwningPtr = dynamic_type::
    DynamicType<dynamic_type::NoContainers, T*, std::unique_ptr<T>>;
I ended up not using this, but it seems useful to leave it in.
static ExprPath getExprsBetween(
    const ValGraph& graph,
    std::vector<NodeType> from,
    std::vector<NodeType> to) {
  ValGraphBFS bfs(graph, std::move(from), std::move(to));
  bfs.traverse();
  return bfs.getShortestExprPath();
}
I ended up not using this either, but I feel it is helpful to keep.
!build
@naoyam This PR is ready for review again. Highlighted changes:
I believe after the change, the code is much more modular and easier to understand now.
csrc/device_lower/utils.cpp
Outdated
// inner_extent=3,
// selected_extent=2}
// =>
// PartOf{what=24, inner_extent=6, selected_extent=2}
To add a little more context here, PartOf{what=24, inner_extent=2, selected_extent=12} can be generated by merging linear_g of extent 12 and a domain of extent 2, i.e., merge(linear_g, [2]) -> PartOf{what=24, inner_extent=2, selected_extent=12}.
Then, what propagation would create this?
PartOf{
what=PartOf{what=24, inner_extent=2, selected_extent=12},
inner_extent=3,
selected_extent=2}
Hmm, I'm confused. PartOf{what=24, inner_extent=2, selected_extent=12} represents the 24 node. Then, IIUC, PartOf{what=PartOf{what=24, inner_extent=2, selected_extent=12},... should be generated by propagating through either merge or split, but your diagram indicates something propagated along the 12 node. What am I missing?
When we propagate from the red 2 to the 24, we first propagate through the merge that generates the 6. At this point, we get:
PartOf{what=6, inner_extent=3, selected_extent=2}
The next step is to propagate through the merge that generates the 12, where we replace the 6 with
PartOf{what=12, inner_extent=nullptr, selected_extent=6}
which gives:
PartOf{what=PartOf{what=12, inner_extent=nullptr, selected_extent=6}, inner_extent=3, selected_extent=2}
which is simplified to
PartOf{what=12, inner_extent=3, selected_extent=2}
Then we propagate through the merge that generates the 24, which replaces the 12 with
PartOf{what=24, inner_extent=2, selected_extent=12}
and we get
PartOf{what=PartOf{what=24, inner_extent=2, selected_extent=12}, inner_extent=3, selected_extent=2}
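To make the nesting-and-simplification step above concrete, here is a minimal sketch using plain integers. It is not the code in this PR: PartOfSketch, selectedIndex, and flattenNested are hypothetical names, extents are concrete int64_t values rather than symbolic Vals, and inner_extent=nullptr is assumed to mean "inner extent 1".

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical integer-only stand-in for the PartOf structure discussed above.
// std::nullopt plays the role of inner_extent=nullptr and is treated as 1
// (an assumption made for this sketch).
struct PartOfSketch {
  int64_t what;
  std::optional<int64_t> inner_extent;
  int64_t selected_extent;
};

// Index of the selected part, given an index into `what`.
inline int64_t selectedIndex(const PartOfSketch& p, int64_t what_index) {
  return (what_index / p.inner_extent.value_or(1)) % p.selected_extent;
}

// Flattens PartOf{what=nested, inner_extent=outer_inner,
// selected_extent=outer_selected} into a single PartOf over nested.what.
inline PartOfSketch flattenNested(
    const PartOfSketch& nested,
    std::optional<int64_t> outer_inner,
    int64_t outer_selected) {
  const int64_t outer_in = outer_inner.value_or(1);
  // The outer selection must tile what the nested PartOf already selected.
  assert(nested.selected_extent % (outer_in * outer_selected) == 0);
  // Inner extents multiply; the selected extent is the outer one.
  const int64_t combined_inner = nested.inner_extent.value_or(1) * outer_in;
  return {nested.what, combined_inner, outer_selected};
}
```

With these definitions, flattenNested({24, 2, 12}, 3, 2) yields what=24, inner_extent=6, selected_extent=2, matching the simplified form quoted in the code comment under review.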
Thanks a lot. This makes sense to me. Can you add this as a code comment as well? I think it's very helpful to understand the logic.
Added to the comment before propagate. Also, I moved the definition of simplify below proveLinearAndGetStrideAfterPropagation so that the comments read in a more natural order.
I also added a "Proof of correctness" to the comment of cancelCommonFactors. The mathematical foundation for this simplification is Theorem 2.1 in https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md.
csrc/device_lower/utils.cpp
Outdated
// canceling the last items in `what`.
//
// Example:
// PartOf{what=[5,3,2], inner_extent=42} => PartOf{what=[5], inner_extent=7}
Could you also mention how PartOf{what=[5,3,2], inner_extent=42} can be generated?
This is not a very good example. A better example could be
PartOf{what=[7,3,2], inner_extent=30}
Do you mean a domain of extent 30 is split twice? I think I understand each split would generate [1, 42], inner_extent=42 and [5, 3, 2], but I don't follow why they could be combined.
Yes, it is split twice. They can be combined because, for example, [1, 42] is the loop domain and [5, 3, 2] is the allocation domain.
Yes, right. I think I finally understand your question after you answered it yourself 😂.
So, about this simplification, I suppose this is the same simplification we discussed earlier ([5, 3] -> [5]). In the case of [5, 3, 2], inner_extent=42 -> [5], inner_extent=7, this means that we view the inner two IDs of [3, 2] as a non-modifiable unit, so it should be effectively equivalent to [5], inner_extent=7. However, what would happen if there's also an expr to traverse that uses [3, 2]? For example, let's say there's a split of 2 to [1, 3]. Since what no longer has [2], this propagation would be just a no-op? Or, what about, e.g., a merge of 3 as the outer and 5 as the inner?
This is a really good question. I love it because it touches the mathematical foundation of the proof we implemented.
When we start from linear_g and propagate to 5, 3, 2, mathematically, we are proving that linear_g ~ f(5, 3, 2), where f is a specific way of operating on its parameters, and ~ denotes "equivalent" (e.g., 6 ~ [3, 2] if you have 6 = merge(3, 2)).
When we run the simplification, we simplify f(5, 3, 2) into g(5), where g is another specific way of operating on its parameters. The process of simplifying f(5, 3, 2) into g(5) is just like a simplification in a math class in which y and z cancel: that y and z are canceled means the value of f(x, y, z) does not depend on the values of y and z. Similarly, in this case, simplifying f(5, 3, 2) into g(5) means that the index of f(5, 3, 2) does not depend on the indices of 3 and 2.
When we further propagate from g(5) to domain, mathematically, we are proving that g(5) ~ h(x1, x2, ...), where x1, x2, ... are ValGroups in domain. Note that this process only depends on the exprs between 5 and domain, nothing else. The split of 2 into [1, 3] is not on this path, so it is unrelated. Merging 3 with 5 is related, but the only relevant information here is "3 is not a parameter of g"; it does not matter where this 3 comes from or whether it is connected to linear_g, because here we are focusing only on the subproblem of proving g(5) ~ h(x1, x2, ...), and linear_g is unrelated to this subproblem. This is just like in a math class: where y comes from, and whether g(x) was simplified from an f(x, y, z) that contains y, is unrelated.
So, IIUC, what you're saying is that no matter what other exprs are traversed, if it uses anything other than 5, it should be just ignored as it shouldn't matter. Is my understanding correct?
I think what I missed was this part:
inner_extent is a multiple of the extent of the last items in what,
So, because of this condition, the canceled factor should not matter for the remaining traversal.
It finally makes sense to me. Thanks for the updated comment. That helped a lot.
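The condition quoted above ("inner_extent is a multiple of the extent of the last items in what") can be illustrated with a small integer-only sketch. This is one plausible reading of the simplification, not the cancelCommonFactors implementation in the PR: PartOfListSketch and cancelTrailingFactors are hypothetical names, and extents are concrete integers rather than symbolic expressions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical integer-only model: `what` lists extents, outermost first.
struct PartOfListSketch {
  std::vector<int64_t> what;
  int64_t inner_extent;
};

// Cancels trailing items of `what` for as long as inner_extent is a multiple
// of the item's extent, dividing inner_extent by each canceled extent.
inline PartOfListSketch cancelTrailingFactors(PartOfListSketch p) {
  assert(p.inner_extent > 0);
  while (!p.what.empty() && p.what.back() > 0 &&
         p.inner_extent % p.what.back() == 0) {
    p.inner_extent /= p.what.back();
    p.what.pop_back();
  }
  return p;
}
```

Applied to the example from the code comment, cancelTrailingFactors({{5, 3, 2}, 42}) cancels 2 and then 3, leaving what=[5] and inner_extent=7. Canceling one item at a time finds the same suffix as checking the product of the last items, because the product of a shorter suffix always divides the product of a longer one.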
csrc/device_lower/utils.cpp
Outdated
count++;
const auto& item = dq.at(dq.size() - count);
what_extent = SimplifyingIrBuilder::mulExpr(what_extent, extent(item));
if (simplifyExpr(IrBuilder::isDivisibleExpr(what_extent, required_extent))
Can this be just what_extent >= required_extent?
I believe not, for a similar reason as in #415, where we changed >= to gcd in the vectorization helper.
Thanks. Can you please add this as a comment?
Added a "Proof of correctness" for trimRedundant. The mathematical foundation for this simplification is Theorem 2.1 in https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md.
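A small numeric sketch of why the divisibility check in this thread cannot be replaced by >=. It mirrors only the loop shape shown in the snippet (accumulate the product of trailing extents, then test it against required_extent); trailingCountDivisibleBy and trailingCountAtLeast are hypothetical names, extents are plain integers, and isDivisibleExpr(a, b) is assumed to mean "a is divisible by b".

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Number of trailing extents whose product first becomes divisible by
// required_extent, or std::nullopt if no suffix qualifies.
inline std::optional<std::size_t> trailingCountDivisibleBy(
    const std::vector<int64_t>& extents,
    int64_t required_extent) {
  assert(required_extent > 0);
  int64_t product = 1;
  for (std::size_t count = 1; count <= extents.size(); ++count) {
    product *= extents[extents.size() - count];
    if (product % required_extent == 0) {
      return count;
    }
  }
  return std::nullopt;
}

// The alternative asked about in review: stop as soon as the product is
// merely at least required_extent.
inline std::optional<std::size_t> trailingCountAtLeast(
    const std::vector<int64_t>& extents,
    int64_t required_extent) {
  assert(required_extent > 0);
  int64_t product = 1;
  for (std::size_t count = 1; count <= extents.size(); ++count) {
    product *= extents[extents.size() - count];
    if (product >= required_extent) {
      return count;
    }
  }
  return std::nullopt;
}
```

For extents [2, 5] and required_extent 4, the >= variant accepts the last item alone (5 >= 4) even though an extent of 5 cannot be cut evenly into pieces of 4, while the divisibility variant finds no valid suffix because neither 5 nor 10 is a multiple of 4.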
// = 512.
namespace {

// From the above example, we can see that how linear_g lives in domain could be
General question: how is reordering handled? Since there's no "reorder" op, the ValGraph traversal itself doesn't seem to account for effects by reordering.
Reorder does not need to be handled, just as we do not need to handle reorder in the vectorization helper. However, we do check the order when we see a merge, and when we run proveLinearAndGetStrideAfterPropagation (basically, any function that uses search checks the order).
Co-authored-by: Naoya Maruyama <[email protected]>
LGTM. Thanks for all the discussions!
Add utility lower_utils::proveLinearAndGetStride, which proves that the provided linear_g is linear with respect to the given domain, and returns the linear coefficient (stride). This utility will be used in the lowering of mma to fill in the "Leading dimension byte offset" and "Stride dimension byte offset" in the matrix descriptor.
See:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#strides
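As a rough numeric illustration of what "linear with a stride" means here, the sketch below brute-forces the property for concrete integers. It is not how the utility works (the PR proves the property symbolically by traversing the ValGraph); strideIfLinear is a hypothetical helper, and the address functions used with it are made-up examples.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>

// Checks numerically whether address_of(g) == address_of(0) + g * stride for
// every g in [0, extent), and returns that stride if so. address_of maps an
// index of the group of interest to a flat offset, with every other index
// held fixed.
inline std::optional<int64_t> strideIfLinear(
    const std::function<int64_t(int64_t)>& address_of,
    int64_t extent) {
  assert(extent >= 2);
  const int64_t base = address_of(0);
  const int64_t stride = address_of(1) - base;
  for (int64_t g = 2; g < extent; ++g) {
    if (address_of(g) != base + g * stride) {
      return std::nullopt; // not linear in g
    }
  }
  return stride;
}
```

For instance, if a contiguous inner dimension of extent 8 is split by 2 and g indexes the outer part (extent 4), the offset within a fixed row 3 and fixed inner index 1 is 3 * 8 + (g * 2 + 1), and the sketch reports stride 2. A swizzled mapping such as g ^ 1 is rejected as non-linear.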