[LLVMGPUVectorDistribute] Add support for distributing masked reads/writes #19830

Merged
7 commits merged into iree-org:main on Feb 20, 2025

Conversation

Contributor

@manupak manupak commented Jan 28, 2025

This commit adds support for distributing masked reads/writes that originate from the `vector.create_mask` op.

@manupak manupak requested a review from Groverkss January 28, 2025 14:31
@manupak manupak changed the title [LLVMGPUVectorDistribute] Add support for distributing masked [LLVMGPUVectorDistribute] Add support for distributing masked reads/writes Jan 28, 2025

This commit adds support for distributing masked reads/writes
that originates from `vector.create_mask` op.

Signed-off-by: Manupa Karunaratne <[email protected]>
…ed form

as other distributions.

Signed-off-by: Manupa Karunaratne <[email protected]>
Signed-off-by: Manupa Karunaratne <[email protected]>

// CHECK: %[[MASK_EXTR:.+]] = vector.extract_strided_slice %[[MASK]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi1> to vector<2x8xi1>
// CHECK: %[[READ:.+]] = vector.transfer_read %arg0{{.*}}, %[[MASK_EXTR]] {in_bounds = [true, true]} : memref<?x128xf16>, vector<2x8xf16>
// CHECK: vector.transfer_write %[[READ]], %arg1{{.*}}, %[[MASK_EXTR]] {in_bounds = [true, true]} : vector<2x8xf16>, memref<?x128xf16>
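
For readers skimming the test, here is a minimal standalone C++ sketch (not the IREE implementation) of what the `vector.extract_strided_slice` in the CHECK lines computes: each thread keeps only the 2x8 portion of the full 8x8 mask that corresponds to its distributed tile. The mask contents in `main` are made up purely for illustration.

```cpp
// Standalone illustration (not IREE code): extract a [2 x 8] slice at offsets
// [0, 0] from an [8 x 8] i1 mask, mirroring the vector.extract_strided_slice
// in the CHECK lines above (unit strides).
#include <array>
#include <cstddef>
#include <iostream>

int main() {
  constexpr std::size_t kRows = 8, kCols = 8;
  constexpr std::size_t kSliceRows = 2, kSliceCols = 8;
  constexpr std::size_t kRowOffset = 0, kColOffset = 0;

  // A hypothetical full mask where only the first 5 columns are active, e.g.
  // the kind of mask vector.create_mask %c8, %c5 : vector<8x8xi1> produces.
  std::array<std::array<bool, kCols>, kRows> mask{};
  for (std::size_t r = 0; r < kRows; ++r)
    for (std::size_t c = 0; c < 5; ++c)
      mask[r][c] = true;

  // The per-thread slice: rows [0, 2), cols [0, 8).
  std::array<std::array<bool, kSliceCols>, kSliceRows> slice{};
  for (std::size_t r = 0; r < kSliceRows; ++r)
    for (std::size_t c = 0; c < kSliceCols; ++c)
      slice[r][c] = mask[kRowOffset + r][kColOffset + c];

  for (std::size_t r = 0; r < kSliceRows; ++r) {
    for (std::size_t c = 0; c < kSliceCols; ++c)
      std::cout << slice[r][c];
    std::cout << "\n"; // prints 11111000 twice
  }
  return 0;
}
```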
Contributor

Can you add tests for nontrivial transfer layouts, e.g.

Transpose

permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>}

and broadcast

permutation_map = affine_map<(d0, d1, d2, d3) -> (0, 0)>}

Contributor Author

I've added cases for permutations now.

Contributor Author

I've added a broadcast test case too.

// Slice the mask at sliceMaskOffsets down to the inner vector shape (unit strides).
SmallVector<int64_t> strides(innerVectorType.getRank(), 1);
slicedMask = rewriter.create<vector::ExtractStridedSliceOp>(
    readOp.getLoc(), mask, sliceMaskOffsets, innerVectorType.getShape(),
    strides);
Contributor

I think this needs to account for the transfer permutation map because the mask is applied before broadcasting/permutations

> An optional SSA value mask may be specified to mask out elements read from the MemRef/Tensor. The mask type is an i1 vector with a shape that matches how elements are read from the MemRef/Tensor, before any permutation or broadcasting. Elements whose corresponding mask element is 0 are masked out and replaced with padding.

https://mlir.llvm.org/docs/Dialects/Vector/#vectortransfer_read-vectortransferreadop

A little confusing, but makes sense after considering how this operation is lowered to masked loads.
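
To make the quoted semantics concrete, here is a small standalone sketch (not the MLIR or IREE implementation) of how the expected mask shape relates to the result shape: sizes are keyed by the memref dimension each result dimension reads from, and broadcast dimensions are assumed to have no counterpart in the mask. The shapes in `main` mirror the memref<128x?x1xf16> / vector<1x256x128xf16> example that comes up later in this thread; `inferMaskShape` and `kBroadcast` are made up for the illustration.

```cpp
// Standalone sketch (not the MLIR/IREE implementation): infer the shape the
// mask of a transfer_read must have, given the result vector shape and, for
// each result dimension, the memref dimension it reads from (kBroadcast for a
// broadcasted dimension). Per the vector dialect docs quoted above, the mask
// matches how elements are read from memory, before permutation/broadcasting.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

constexpr int64_t kBroadcast = -1;

std::vector<int64_t>
inferMaskShape(const std::vector<int64_t> &resultShape,
               const std::vector<int64_t> &srcDimForResultDim) {
  assert(resultShape.size() == srcDimForResultDim.size());
  // Key the sizes by the memref dimension being read; broadcast dims are
  // assumed not to appear in the mask. std::map keeps memref-dimension order.
  std::map<int64_t, int64_t> sizeBySrcDim;
  for (size_t i = 0; i < resultShape.size(); ++i)
    if (srcDimForResultDim[i] != kBroadcast)
      sizeBySrcDim[srcDimForResultDim[i]] = resultShape[i];
  std::vector<int64_t> maskShape;
  for (const auto &[srcDim, size] : sizeBySrcDim)
    maskShape.push_back(size);
  return maskShape;
}

int main() {
  // vector<1x256x128> read with permutation_map (d0, d1, d2) -> (d2, d1, d0):
  // result dim 0 reads memref dim 2, dim 1 reads dim 1, dim 2 reads dim 0.
  for (int64_t s : inferMaskShape({1, 256, 128}, {2, 1, 0}))
    std::cout << s << " "; // prints: 128 256 1 (matches vector<128x256x1xi1>)
  std::cout << "\n";
  return 0;
}
```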

Contributor Author

oh

Contributor Author

I think I'm going to bail on broadcasts for now (I'll add code to fail here), as they are not needed immediately and there is no good way of supporting them unless we change how transfer_reads are lowered.

I'll add support for permutations.

Contributor Author
@manupak manupak Feb 18, 2025

After having written the code, I don't think the distribution has to account for permutations; otherwise, help me understand why it has to.

As an example:

pre-distribution IR:

%41 = vector.create_mask %c128, %dyn, %c1 : vector<128x256x1xi1>
%42 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst_6, %41 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, d1, d0)>} : memref<128x?x1xf16>, vector<1x256x128xf16>

So it is already permuted in the pre-distribution domain.
What actually breaks is the vector layout enforcement that happens early on: when the layout is enforced on the create_mask op, we need to permute the layout.

Contributor Author

Yep, that works. I've added a test now.
I'll add another one with a minor identity.

Contributor

> I think I'm going to bail on broadcasts for now (I'll add code to fail here)

Understandable, I think this was one of the main reasons I bailed on masks altogether in the first version of this pattern.

> After having written the code, I don't think the distribution has to account for permutations
> Yep, that works. I've added a test now.

Cool, if it works, sounds good (although I'm not sure I follow why exactly). The distribution pattern does unrolling that I thought would have to account for the permutation (even if the distribution portion is handled by the layout propagation like you said).

Contributor Author

> Cool, if it works, sounds good (although I'm not sure I follow why exactly). The distribution pattern does unrolling that I thought would have to account for the permutation (even if the distribution portion is handled by the layout propagation like you said).

Well, you are right there.
I need to get the StaticTileOffsetRange based on the layout of the mask, not the result of the read, if that makes sense.
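
A rough standalone sketch of that point, assuming the unrolling enumerates tiles of the undistributed shape in the same spirit as StaticTileOffsetRange (the 8x8 mask and 2x8 tile below are taken from the test above; `tileOffsets` is made up for the illustration, not the actual pass code):

```cpp
// Standalone sketch (illustrative only): enumerate the offsets of all tiles of
// size `tileSizes` within `shape`, in row-major order, similar in spirit to
// mlir::StaticTileOffsetRange. The point from the discussion above: when the
// read result and the mask differ by a permutation, this enumeration has to be
// driven by the mask's shape/layout, not the result's.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::vector<int64_t>>
tileOffsets(const std::vector<int64_t> &shape,
            const std::vector<int64_t> &tileSizes) {
  std::vector<std::vector<int64_t>> offsets;
  std::vector<int64_t> cur(shape.size(), 0);
  while (true) {
    offsets.push_back(cur);
    // Advance like an odometer, innermost dimension fastest.
    int64_t dim = static_cast<int64_t>(shape.size()) - 1;
    for (; dim >= 0; --dim) {
      cur[dim] += tileSizes[dim];
      if (cur[dim] < shape[dim])
        break;
      cur[dim] = 0;
    }
    if (dim < 0)
      return offsets;
  }
}

int main() {
  // Mask shape 8x8 walked in 2x8 tiles (as in the test above) yields the
  // offsets [0, 0], [2, 0], [4, 0], [6, 0].
  for (const auto &off : tileOffsets({8, 8}, {2, 8}))
    std::cout << "[" << off[0] << ", " << off[1] << "]\n";
  return 0;
}
```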

Contributor Author

(Now updated, with tests, to reflect the above.)

Contributor Author

> I think I'm going to bail on broadcasts for now (I'll add code to fail here)

> Understandable, I think this was one of the main reasons I bailed on masks altogether in the first version of this pattern.

With the same logic, I've added this now and it works.
In the vector layout enforcement, we just needed to drop the broadcasted dims.
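
For what it's worth, a tiny standalone sketch of the "drop the broadcasted dims" step on a per-dimension layout description (purely illustrative; the real enforcement works on IREE's layout attributes, and `dropBroadcastDims` and the string labels below are placeholders):

```cpp
// Standalone sketch (illustrative only): given a per-dimension description of
// the read result's layout and a flag per result dimension saying whether it
// is broadcast, keep only the non-broadcast dimensions to obtain the layout to
// enforce on the mask. The string labels stand in for whatever per-dimension
// layout information the real attribute carries.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string>
dropBroadcastDims(const std::vector<std::string> &resultDimLayouts,
                  const std::vector<bool> &isBroadcastDim) {
  std::vector<std::string> maskDimLayouts;
  for (size_t i = 0; i < resultDimLayouts.size(); ++i)
    if (!isBroadcastDim[i])
      maskDimLayouts.push_back(resultDimLayouts[i]);
  return maskDimLayouts;
}

int main() {
  // Hypothetical 3-D result layout where the middle dimension is broadcast.
  auto maskLayout = dropBroadcastDims(
      {"dim0-layout", "dim1-layout", "dim2-layout"}, {false, true, false});
  for (const auto &l : maskLayout)
    std::cout << l << "\n"; // dim0-layout, dim2-layout
  return 0;
}
```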

SmallVector<int64_t> getPackedShapeForUndistributedDim(int64_t dim) const;

// Get the distributed shape, but with the same rank as the undistributed shape.
SmallVector<int64_t> getDistributedUnpackedShape() const;
Contributor

Echoing @bjacob here: #19905 (comment)

These could be freestanding functions instead of class members if it's not too much trouble.

Contributor Author
@manupak manupak Feb 18, 2025

Yes, they could be (and I'm happy to change to that), but I'd like to understand/learn the rationale, especially since the computation uses the state of the object, and especially when they are MLIR attributes.

To me, f(object, ...) vs object.f(...) is a traditional argument in C++, where the former is preferred for encapsulation reasons as it cannot access private/protected members, but I thought that does not necessarily hold for MLIR tablegen'd attributes.
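
As a plain-C++ aside (a toy, not the IREE code), the two shapes being compared look like this; with no private state involved, both forms can access exactly the same data, so the difference is mostly about keeping the class interface small. `ToyLayout` and `getNumElements` are invented for the example.

```cpp
// Toy illustration (plain C++, no MLIR) of the two API shapes being discussed:
// the same computation written as a member function and as a freestanding
// function taking the object as an argument.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

struct ToyLayout {
  std::vector<int64_t> shape; // public state, as with a tablegen'd attribute

  // Member-function form: layout.getNumElements()
  int64_t getNumElements() const {
    return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }
};

// Freestanding form: getNumElements(layout); same computation, same access.
int64_t getNumElements(const ToyLayout &layout) {
  return std::accumulate(layout.shape.begin(), layout.shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  ToyLayout layout{{2, 8}};
  std::cout << layout.getNumElements() << " " << getNumElements(layout) << "\n";
  return 0;
}
```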

* same test case covers where the map is a permuted identity minor.

Signed-off-by: Manupa Karunaratne <[email protected]>
@manupak manupak requested a review from qedawkins February 19, 2025 10:37
Contributor Author

manupak commented Feb 19, 2025

Hi @qedawkins,

Thanks for the review. I've added support for permutations and broadcasts now.

I'm awaiting a response on freestanding vs member functions for tablegen'd attributes, as I'm not seeing a difference between the two, especially when the attribute class does not have protected/private members.
(I'm just looking for a reason to convince myself to own the change.)

Contributor
@qedawkins qedawkins left a comment

It took me a while to figure out why you needed to do the packing/unpacking/deinterleaving thing, but this makes sense. LGTM

SmallVector<unsigned> permutation;
AffineMap permMap = read.getPermutationMap();
bool isSupportedPerm =
    permMap.isPermutationOfMinorIdentityWithBroadcasting(permutation);
Contributor

I believe this is enforced by the verifier, so it's technically not required, but fine to have just in case.

Contributor Author

Basically, it gets the permutation in the result domain, assuming broadcasts for missing dims,
which is the permutation we need when we go from the read result layout to the mask layout.

This was inspired by how it's actually lowered (after you pointed it out :))

https://github.com/llvm/llvm-project/blob/3e61c1ab7f5d9666db88069d49c8916c40fae5ea/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp#L107-L152
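
Roughly, the direction of that mapping reads like the sketch below (illustrative only; the exact index convention is the one used in the linked upstream lowering, and `permuteLayout` and the labels are made up): the read result's per-dimension layout is reordered into the mask's dimension order.

```cpp
// Standalone sketch (illustrative only): permute the per-dimension layout of
// the read result into the order of the mask's dimensions. perm[i] names the
// result dimension that mask dimension i corresponds to; the labels stand in
// for whatever per-dimension layout information the real attribute carries.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string>
permuteLayout(const std::vector<std::string> &resultDimLayouts,
              const std::vector<int64_t> &perm) {
  std::vector<std::string> maskDimLayouts;
  for (int64_t srcDim : perm)
    maskDimLayouts.push_back(resultDimLayouts[srcDim]);
  return maskDimLayouts;
}

int main() {
  // For the vector<1x256x128> read with permutation_map (d0, d1, d2) ->
  // (d2, d1, d0) from earlier in the thread, the mask is vector<128x256x1xi1>,
  // i.e. the result's per-dimension layout reversed.
  auto maskLayout = permuteLayout(
      {"d(size=1)", "d(size=256)", "d(size=128)"}, {2, 1, 0});
  for (const auto &l : maskLayout)
    std::cout << l << "\n"; // d(size=128), d(size=256), d(size=1)
  return 0;
}
```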

Contributor

Ah right, I missed that it's populating `permutation`. You can ignore me then.

@manupak manupak merged commit b113829 into iree-org:main Feb 20, 2025
40 checks passed