[LLVMGPUVectorDistribute] Add support for distributing masked reads/writes #19830
Conversation
Force-pushed from 55403f9 to 4c643cd
[LLVMGPUVectorDistribute] Add support for distributing masked reads/writes: this commit adds support for distributing masked reads/writes that originate from the `vector.create_mask` op. Signed-off-by: Manupa Karunaratne <[email protected]>
…ed form as other distributions. Signed-off-by: Manupa Karunaratne <[email protected]>
Signed-off-by: Manupa Karunaratne <[email protected]>
Force-pushed from 4c643cd to b6ed503
```mlir
// CHECK: %[[MASK_EXTR:.+]] = vector.extract_strided_slice %[[MASK]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi1> to vector<2x8xi1>
// CHECK: %[[READ:.+]] = vector.transfer_read %arg0{{.*}}, %[[MASK_EXTR]] {in_bounds = [true, true]} : memref<?x128xf16>, vector<2x8xf16>
// CHECK: vector.transfer_write %[[READ]], %arg1{{.*}}, %[[MASK_EXTR]] {in_bounds = [true, true]} : vector<2x8xf16>, memref<?x128xf16>
```
Can you add tests for nontrivial transfer layouts? E.g. transpose (iree/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir, line 147 at be3c729):

```mlir
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>}
```

and broadcast (same file, line 104 at be3c729):

```mlir
permutation_map = affine_map<(d0, d1, d2, d3) -> (0, 0)>}
```
I've added cases for permutations now.
I've added a broadcast test case too.
```cpp
SmallVector<int64_t> strides(innerVectorType.getRank(), 1);
slicedMask = rewriter.create<vector::ExtractStridedSliceOp>(
    readOp.getLoc(), mask, sliceMaskOffsets, innerVectorType.getShape(),
    strides);
```
I think this needs to account for the transfer permutation map, because the mask is applied before broadcasting/permutations:

> An optional SSA value mask may be specified to mask out elements read from the MemRef/Tensor. The mask type is an i1 vector with a shape that matches how elements are read from the MemRef/Tensor, before any permutation or broadcasting. Elements whose corresponding mask element is 0 are masked out and replaced with padding.

https://mlir.llvm.org/docs/Dialects/Vector/#vectortransfer_read-vectortransferreadop

A little confusing, but it makes sense after considering how this operation is lowered to masked loads.
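For concreteness, here is a minimal hypothetical sketch (all shapes and value names invented) of what that doc sentence means for a transposed read: the mask is shaped by the memref-side iteration order, not by the permuted result vector:

```mlir
// The mask covers elements as they are read from the 16x8 memref;
// the permutation_map then transposes them into the 8x16 result.
%mask = vector.create_mask %c10, %c4 : vector<16x8xi1>
%v = vector.transfer_read %mem[%c0, %c0], %pad, %mask
    {in_bounds = [true, true],
     permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
    : memref<16x8xf16>, vector<8x16xf16>
```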
oh
I think I'm going to bail on broadcasts for now (I'll add code to fail here), as they are not needed immediately, and there is no good way of supporting them unless we change how transfer_reads are lowered.
I'll add support for permutations.
After having written the code, I don't think the distribution has to account for permutations (or help me understand why it has to). As an example, the pre-distribution IR:

```mlir
%41 = vector.create_mask %c128, %dyn, %c1 : vector<128x256x1xi1>
%42 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst_6, %41 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, d1, d0)>} : memref<128x?x1xf16>, vector<1x256x128xf16>
```

So it is already permuted in the pre-distribution domain. What actually breaks is the vector layout enforcement that happens early on: when the layout is enforced on the create_mask op, we need to permute the layout.
Yep, that works. I've added a test now.
I'll add another one with a minor identity.
> I think I'm going to bail on broadcasts for now (I'll add code to fail here)

Understandable, I think this was one of the main reasons I bailed on masks altogether in the first version of this pattern.

> after having written the code, I don't think the distribution has to account for permutations

> yep. that works. I've added a test now.

Cool, if it works, sounds good (although I'm not sure I follow why exactly). The distribution pattern does unrolling that I thought would have to account for the permutation (even if the distribution portion is handled by the layout propagation like you said).
> Cool, if it works, sounds good (although I'm not sure I follow why exactly). The distribution pattern does unrolling that I thought would have to account for the permutation (even if the distribution portion is handled by the layout propagation like you said).

Well, you are right there. I need to get the StaticTileOffsetRange based on the layout of the mask, not the result of the read, if that makes sense.
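A hypothetical sketch (shapes invented) of why the tile offsets live in mask space: for a transposed read unrolled into result-space tiles, each mask slice is taken at the permuted offsets, so the result tile at result offsets [0, 8] pairs with the mask slice at [8, 0]:

```mlir
// Mask slice offsets follow the mask dims (d0, d1) of the full
// vector<16x8xi1> mask, not the transposed result dims.
%tile_mask = vector.extract_strided_slice %mask
    {offsets = [8, 0], sizes = [8, 4], strides = [1, 1]}
    : vector<16x8xi1> to vector<8x4xi1>
%tile = vector.transfer_read %mem[%c8, %c0], %pad, %tile_mask
    {in_bounds = [true, true],
     permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
    : memref<16x8xf16>, vector<4x8xf16>
```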
(Now updated, with tests, to reflect the above.)
> I think I'm going to bail on broadcasts for now (I'll add code to fail here)

> Understandable, I think this was one of the main reasons I bailed on masks altogether in the first version of this pattern.

With the same logic, I've added this now and it works.
In the vector layout enforcement, we just needed to drop the broadcasted dims.
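That matches the transfer_read mask semantics quoted earlier: broadcast result dims have no corresponding mask dims. A minimal hypothetical illustration (shapes invented):

```mlir
// Result dim 0 (size 8) is a broadcast ((0, d0) in the map), so the mask
// only covers the memref-side dim d0; the broadcast dim is dropped.
%mask = vector.create_mask %c3 : vector<4xi1>
%v = vector.transfer_read %mem[%c0], %pad, %mask
    {in_bounds = [true, true],
     permutation_map = affine_map<(d0) -> (0, d0)>}
    : memref<4xf16>, vector<8x4xf16>
```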
```cpp
SmallVector<int64_t> getPackedShapeForUndistributedDim(int64_t dim) const;

// Get the distributed shape, but with the same rank as the undistributed shape.
SmallVector<int64_t> getDistributedUnpackedShape() const;
```
Echoing @bjacob here: #19905 (comment)
These could be freestanding functions instead of class members if it's not too much trouble.
Yes, they could be (and I'm happy to change to that), but I'd like to understand/learn the rationale, especially since they use the state of the object to perform the required computation, and especially when they are MLIR attributes.
To me, f(object, ...) vs object.f(...) is a traditional argument in C++ where the former is preferred for encapsulation reasons, as it cannot access private/protected members, but I thought that does not necessarily hold for MLIR tablegen'd attributes.
Signed-off-by: Manupa Karunaratne <[email protected]>
Signed-off-by: Manupa Karunaratne <[email protected]>
* same test case covers where the map is a permuted minor identity. Signed-off-by: Manupa Karunaratne <[email protected]>
Force-pushed from 0dbc4da to 9253abc
Signed-off-by: Manupa Karunaratne <[email protected]>
Hi @qedawkins, thanks for the review. I've added support for permutations and broadcast now. I'm awaiting a response on freestanding vs member functions for tablegen'd attributes, as I'm not seeing a difference between the two, especially when the attribute class does not have protected/private members.
It took me a while to figure out why you needed to do the packing/unpacking/deinterleaving thing, but this makes sense. LGTM
```cpp
SmallVector<unsigned> permutation;
AffineMap permMap = read.getPermutationMap();
bool isSupportedPerm =
    permMap.isPermutationOfMinorIdentityWithBroadcasting(permutation);
```
I believe this is already enforced by the verifier, so the check is technically not required, but fine to have just in case.
Basically, it gets the permutation in the result domain, assuming broadcasts for missing dims, which is the permutation we need when we go from the read result layout to the mask layout.
This was inspired by how it's actually lowered (after you pointed it out :)).
Ah right, I missed that it's populating `permutation`. You can ignore me then.
This commit adds support for distributing masked reads/writes that originate from the `vector.create_mask` op.