-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathcopy.hpp
527 lines (472 loc) · 18.7 KB
/
copy.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/tensor_impl.hpp> // cute::Tensor
#include <cute/tensor_predicate.hpp> // cute::TrivialPredTensor
#include <cute/atom/copy_atom.hpp> // cute::Copy_Atom
namespace cute
{
//
// copy_if -- Predicated Copy
//
// Element-wise predicated copy.
//
// Walks every linear index of `dst`. Where `pred(idx)` holds, `src(idx)` is read as the
// source engine's value type, converted to the destination engine's value type, and
// assigned to `dst(idx)`. Indices whose predicate does not hold are not written.
template <class PrdTensor,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(PrdTensor                    const& pred,
        Tensor<SrcEngine, SrcLayout> const& src,
        Tensor<DstEngine, DstLayout>      & dst)
{
  using SrcValue = typename SrcEngine::value_type;
  using DstValue = typename DstEngine::value_type;

  CUTE_UNROLL
  for (int idx = 0; idx < size(dst); ++idx) {
    if (pred(idx)) {
      // Inner cast materializes the source value; outer cast converts it for the destination.
      dst(idx) = static_cast<DstValue>(static_cast<SrcValue>(src(idx)));
    }
  }
}
//
// copy_if -- Predicated CopyAtom
//
// Predicated copy through a Copy_Atom.
//
// `src` and `dst` must have the same rank (checked at compile time). Mode 0 is the part
// handed to a single atom invocation; all remaining modes are iterated over, and `pred`
// is queried once per iteration.
//
// Two predication strategies are selected at compile time:
//   - If the atom's Traits type accepts `.with(true)`, the predicate value is passed to the
//     atom via `copy_atom.with(p).call(...)` and the call is always issued.
//   - Otherwise the call is guarded by a plain `if (p)`.
template <class... CopyArgs,
class PredTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
PredTensor const& pred, // (Rest...)
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
{
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
// Detect whether `Traits::with(bool)` is a well-formed expression for this atom.
// The result is used below as an `if constexpr` condition.
auto has_with_bool = cute::is_valid([](auto t)->void_t<decltype(declval<typename decltype(t)::Traits>().with(true))>{}, copy_atom);
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
// Rank-1 tensors: a single atom call; the predicate is evaluated with no coordinate.
if constexpr (has_with_bool) {
copy_atom.with(pred()).call(src, dst);
} else {
if (pred()) { copy_atom.call(src, dst); }
}
} else { // Loop over all but the first mode
constexpr int R = SrcLayout::rank;
// Collapse modes [1,R) into one mode so both tensors can be indexed as (V, i).
Tensor src_v = group_modes<1,R>(src);
Tensor dst_v = group_modes<1,R>(dst);
CUTE_UNROLL
for (int i = 0; i < size<1>(dst_v); ++i) {
if constexpr (has_with_bool) {
copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i));
} else {
if (pred(i)) { copy_atom.call(src_v(_,i), dst_v(_,i)); }
}
}
}
}
//
// copy_if -- AutoCopyAsync
//
// Predicated element-wise copy that picks a copy operation type at compile time.
//
// `cpy` is not referenced in the body; it only selects this overload.
//
// Operation selection (done once, by the immediately-invoked lambda):
//   - When CUTE_ARCH_CP_ASYNC_SM80_ENABLED is defined, is_gmem<SrcEngine> and
//     is_smem<DstEngine> both hold, and the source and destination value types have the
//     same size:
//       * const source element of 16 bytes    -> SM80_CP_ASYNC_CACHEGLOBAL
//       * source element of 4, 8, or 16 bytes -> SM80_CP_ASYNC_CACHEALWAYS
//       * any other size                      -> UniversalCopy
//   - In every other case                     -> UniversalCopy
//
// The chosen operation's `copy(src(i), dst(i))` is then invoked for every linear index `i`
// of `dst` for which `pred(i)` holds.
// NOTE(review): whether the SM80 operations complete asynchronously and require a later
// wait/fence is not visible in this function -- confirm against their definitions.
template <class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(AutoCopyAsync const& cpy,
PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
// Source element type with its const-qualification preserved (reference stripped);
// only consulted in the CP_ASYNC-enabled branch.
using SrcElemWithConst = remove_reference_t<typename SrcEngine::reference>;
using SrcType = typename SrcEngine::value_type;
using DstType = typename DstEngine::value_type;
// Each `if constexpr` branch returns a different type, so exactly one return statement
// is instantiated and determines the lambda's return type.
auto copy_op = []() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
if constexpr (is_gmem<SrcEngine>::value && is_smem<DstEngine>::value &&
sizeof(SrcType) == sizeof(DstType)) {
if constexpr (is_const_v<SrcElemWithConst> && sizeof(SrcType) == 16) {
return SM80_CP_ASYNC_CACHEGLOBAL<SrcType,DstType>{};
} else if constexpr (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16) {
return SM80_CP_ASYNC_CACHEALWAYS<SrcType,DstType>{};
} else {
return UniversalCopy<SrcType,DstType>{};
}
} else {
return UniversalCopy<SrcType,DstType>{};
}
// Every branch above returns; this marker follows them all.
CUTE_GCC_UNREACHABLE;
#else
return UniversalCopy<SrcType,DstType>{};
#endif
}();
CUTE_UNROLL
for (int i = 0; i < size(dst); ++i) {
if (pred(i)) {
copy_op.copy(src(i), dst(i));
}
}
}
//
// copy -- AutoCopyAsync
//
// Unpredicated AutoCopyAsync copy: forwards to the AutoCopyAsync `copy_if` with a
// default-constructed TrivialPredTensor as the predicate.
template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(AutoCopyAsync                const& cpy,
     Tensor<SrcEngine, SrcLayout> const& src,   // (V,Rest...)
     Tensor<DstEngine, DstLayout>      & dst)   // (V,Rest...)
{
  TrivialPredTensor const no_predicate{};
  copy_if(cpy, no_predicate, src, dst);
}
//
// copy -- CopyAtom
//
// Unpredicated copy through a Copy_Atom.
//
// `src` and `dst` must have the same rank (checked at compile time). Mode 0 is the part
// handed to a single `copy_atom.call`; the remaining modes are grouped into one "Rest"
// mode and iterated.
//
// When both grouped shapes are static, the Rest mode is additionally filtered through
// the nullspace of the destination's Rest layout before looping, and several
// compile-time checks guard that transformation. With dynamic shapes the Rest mode is
// iterated as-is.
template <class... CopyArgs,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Copy_Atom<CopyArgs...> const& copy_atom,
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
{
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
copy_atom.call(src, dst);
} else { // Loop over all but the first mode
constexpr int R = SrcLayout::rank;
// Collapse modes [1,R) into one mode so both tensors can be indexed as (V, i).
Tensor src_v = group_modes<1,R>(src);
Tensor dst_v = group_modes<1,R>(dst);
if constexpr (is_static<decltype(shape(src_v))>::value && is_static<decltype(shape(dst_v))>::value) {
CUTE_STATIC_ASSERT_V(size<1>(src_v) == size<1>(dst_v));
// AutoFilter on the Rest-mode
// NOTE(review): dst_null is presumably the set of Rest coordinates that all map to the
// same destination offset -- confirm against nullspace()'s definition.
auto dst_null = nullspace(layout<1>(dst_v));
// Tile both tensors by (full V mode, dst_null). The same tiler is applied to src so the
// two results stay coordinate-compatible.
Tensor dst_n = zipped_divide(dst_v, make_tile(shape<0>(dst_v), dst_null)); // ((V, NLL), (_1, Rest))
Tensor src_n = zipped_divide(src_v, make_tile(shape<0>(src_v), dst_null)); // ((V, NLL), (_1, Rest))
CUTE_STATIC_ASSERT_V(size<1>(src_n) == size<1>(dst_n));
// The NLL sub-mode must have cosize 1 in both tensors; otherwise compilation stops with
// the messages below.
CUTE_STATIC_ASSERT_V((cosize<0,1>(dst_n.layout()) == Int<1>{}), "Nullspace definition error");
CUTE_STATIC_ASSERT_V((cosize<0,1>(src_n.layout()) == Int<1>{}), "Error: Ambiguous scatter detected in copy");
CUTE_STATIC_ASSERT_V((size<1,0>(dst_n) == Int<1>{}));
CUTE_STATIC_ASSERT_V((size<1,0>(src_n) == Int<1>{}));
// Keep all of V, fix the NLL coordinate and the size-1 leading rest coordinate at 0,
// and keep all of the filtered Rest.
Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
CUTE_STATIC_ASSERT_V(size<1>(src_c) == size<1>(dst_c));
// The V mode must be unchanged by the filtering.
CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst));
CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src));
CUTE_UNROLL
for (int i = 0; i < size<1>(dst_c); ++i) {
copy_atom.call(src_c(_,i), dst_c(_,i));
}
} else {
// Dynamic shapes: no filtering, one atom call per Rest index.
CUTE_UNROLL
for (int i = 0; i < size<1>(dst_v); ++i) {
copy_atom.call(src_v(_,i), dst_v(_,i));
}
}
}
}
////////////////////////////////////////////////////////
// Special Auto-Vectorizing, Auto-Filtering Overloads //
////////////////////////////////////////////////////////
// Specialization for AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>
// Copy that widens the element type when a common vector width can be proven at compile time.
//
// The candidate vector width in bits is the gcd of
//   - (number of commonly vectorizable elements of src/dst) * (bits per source element), and
//   - gcd(max_alignment(src), max_alignment(dst), MaxVecBits).
// If more than one element vectorizes and that width is a whole number of bytes, both
// tensors are recast to an unsigned integer type of that width (keeping any volatile
// qualification of the original element types) and copied element-wise. Otherwise the
// tensors are copied element-wise with their original element types.
//
// The `Args` pack is not used by this overload.
template <int MaxVecBits, class... Args,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
     Tensor<SrcEngine, SrcLayout>                        const& src,
     Tensor<DstEngine, DstLayout>                             & dst)
{
  using SrcValue = typename SrcEngine::value_type;

  constexpr int num_common_elems = CUTE_STATIC_V(max_common_vector(src, dst));
  constexpr int alignment_bits   = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int<MaxVecBits>{}));
  static_assert(is_integral<decltype(Int<num_common_elems>{} * sizeof_bits_v<SrcValue>)>::value, "Error: Attempting a subbit copy!");
  constexpr int vector_bits = gcd(num_common_elems * sizeof_bits_v<SrcValue>, alignment_bits);

  constexpr bool can_vectorize = (num_common_elems > 1) && ((vector_bits % 8) == 0);

  if constexpr (!can_vectorize) {
    // Nothing to widen: plain element-wise copy.
    copy_if(TrivialPredTensor{}, src, dst);
  } else {
    // More than one element fits into a byte-multiple vector: recast, then copy.
    using Vec = uint_bit_t<vector_bits>;

    // Carry over volatility from the original element types.
    using SrcVec = conditional_t<is_volatile_v<typename SrcEngine::element_type>, Vec const volatile, Vec const>;
    using DstVec = conditional_t<is_volatile_v<typename DstEngine::element_type>, Vec       volatile, Vec      >;

    Tensor src_vec = recast<SrcVec>(src);
    Tensor dst_vec = recast<DstVec>(dst);
    copy_if(TrivialPredTensor{}, src_vec, dst_vec);
  }
}
// Thin wrapper around a copy policy object, held by const reference in `base`.
//
// Only a reference is stored: the wrapped object must outlive the AutoFilter. Wrapping a
// temporary is therefore only safe within the full expression that creates it.
template <class Base>
struct AutoFilter {
  Base const& base;   // the wrapped policy (not owned)

  CUTE_HOST_DEVICE
  AutoFilter(Base const& wrapped)
    : base(wrapped)
  {}
};
// Specialization for AutoFilter
// Copy with the wrapped policy `copy_op.base`, first filtering both tensors by the
// nullspace of the destination layout when that is possible at compile time.
//
// Filtering is applied only if `size(src) == size(dst)` is a compile-time `true`.
// In that case both tensors are divided by the destination's nullspace and only the
// slice at nullspace coordinate 0 is copied; two static assertions check that the
// nullspace mode has cosize 1 in both tensors. Otherwise the tensors are forwarded
// to the wrapped policy unchanged.
template <class CopyOp,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(AutoFilter<CopyOp>           const& copy_op,
     Tensor<SrcEngine, SrcLayout> const& src,
     Tensor<DstEngine, DstLayout>      & dst)
{
  constexpr bool sizes_statically_equal = is_constant<true, decltype(size(src) == size(dst))>::value;

  if constexpr (!sizes_statically_equal) {
    copy(copy_op.base, src, dst);
  } else {
    auto   null_tiler = nullspace(dst.layout());
    Tensor dst_tiled  = zipped_divide(dst, null_tiler);
    Tensor src_tiled  = zipped_divide(src, null_tiler);

    CUTE_STATIC_ASSERT_V(cosize<0>(dst_tiled.layout()) == Int<1>{}, "Nullspace definition error");
    CUTE_STATIC_ASSERT_V(cosize<0>(src_tiled.layout()) == Int<1>{}, "Error: Ambiguous scatter detected in copy");

    // Copy only the slice at nullspace coordinate 0.
    copy(copy_op.base, src_tiled(Int<0>{},_), dst_tiled(Int<0>{},_));
  }
}
// Auto-vectorizing copy for static layouts
// Default copy: chooses an auto-vectorizing policy from how much of the tensors is static.
//
//   - both layouts fully static -> AutoFilter + assumed alignment of 128 bits
//   - only both shapes static   -> AutoFilter + assumed alignment of 8 bits
//   - otherwise                 -> assumed alignment of 8 bits, no AutoFilter
template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Tensor<SrcEngine, SrcLayout> const& src,
     Tensor<DstEngine, DstLayout>      & dst)
{
  constexpr bool both_layouts_static = is_static<SrcLayout>::value && is_static<DstLayout>::value;

  if constexpr (both_layouts_static) {
    // Tensors with static layouts (e.g. registers) are assumed to have 128b-aligned pointers.
    using Aligned128 = AutoVectorizingCopyWithAssumedAlignment<128>;
    copy(AutoFilter(Aligned128{}), src, dst);
  } else if constexpr (is_static<decltype(shape(src))>::value && is_static<decltype(shape(dst))>::value) {
    // Static shapes allow filtering, but no alignment is assumed for dynamic layouts.
    using Aligned8 = AutoVectorizingCopyWithAssumedAlignment<8>;
    copy(AutoFilter(Aligned8{}), src, dst);
  } else {
    // Dynamic layouts: no alignment assumption and no filtering.
    using Aligned8 = AutoVectorizingCopyWithAssumedAlignment<8>;
    copy(Aligned8{}, src, dst);
  }
}
// Auto-vectorizing copy with assumed alignment up to 128bit.
// Auto-vectorizing copy that always assumes alignment of up to 128 bits.
// The policy is wrapped in AutoFilter only when both tensor shapes are static.
template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_aligned(Tensor<SrcEngine, SrcLayout> const& src,
             Tensor<DstEngine, DstLayout>      & dst)
{
  using Aligned128 = AutoVectorizingCopyWithAssumedAlignment<128>;
  constexpr bool both_shapes_static = is_static<decltype(shape(src))>::value && is_static<decltype(shape(dst))>::value;

  if constexpr (both_shapes_static) {
    // Tensors with static shapes can be filtered.
    copy(AutoFilter(Aligned128{}), src, dst);
  } else {
    copy(Aligned128{}, src, dst);
  }
}
// Specialization for Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, ...>
// A Copy_Atom built on AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> is treated as
// the bare policy: the atom object itself is ignored and a default-constructed policy of
// the same MaxVecBits is used instead.
template <int MaxVecBits, class... Args,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, Args...> const&,
     Tensor<SrcEngine, SrcLayout>                                            const& src,
     Tensor<DstEngine, DstLayout>                                                 & dst)
{
  using Policy = AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>;
  copy(Policy{}, src, dst);
}
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
// Bulk copy with automatically chosen direction and vector size.
//
// Compile-time requirements enforced below:
//   - source and destination value types are identical;
//   - the transfer is either gmem -> smem or smem -> gmem (per is_gmem / is_smem);
//   - the largest common contiguous piece of src and dst spans at least 128 bits.
//
// A concrete Copy_Atom for the chosen direction and vector width is constructed from
// `atom.opargs_`, and the copy is re-dispatched on that atom with both tensors divided
// by the common tiler.
template <class... CT_Args,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const& atom, // Copy_Traits may or may not have the memory barrier in it already
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
using SrcType = typename SrcEngine::value_type;
using DstType = typename DstEngine::value_type;
static_assert(cute::is_same<SrcType, DstType>::value);
static_assert((is_gmem<SrcEngine>::value && is_smem<DstEngine>::value) ||
(is_smem<SrcEngine>::value && is_gmem<DstEngine>::value),
"Bulk Copy only supports gmem -> smem or smem -> gmem movement.");
// G2S or S2G dispatch
// The direction follows the source: a gmem source selects G2S, anything else S2G
// (which, given the assertion above, means an smem source).
using BULK_COPY_OP = conditional_t<is_gmem<SrcEngine>::value,
SM90_BULK_COPY_G2S,
SM90_BULK_COPY_S2G>;
// Find the common subtensor of src and dst
auto tiler = max_common_layout(src, dst);
// size(tiler) must be a static integer; it is read through its type.
constexpr int vec_elem = decltype(size(tiler))::value;
constexpr int vec_bits = vec_elem * sizeof_bits_v<SrcType>;
static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP");
// Construct a new concrete Atom of the vector size
using BulkAtom = Copy_Atom<Copy_Traits<BULK_COPY_OP, Int<vec_bits>, CT_Args...>, SrcType>;
// Forward the original traits' operation arguments into the new atom.
auto bulk_atom = apply(atom.opargs_, [](auto const&... args) { return BulkAtom{args...}; });
return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler));
}
// Backwards-compat. Throw out any extra Copy_Atom args.
// Backwards-compatibility entry point: a Copy_Atom wrapping SM90_BULK_COPY_AUTO traits is
// cast to those traits (dropping the extra Copy_Atom arguments) and copied via the
// Copy_Traits overload.
template <class... CT_Args, class... CA_Args,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
     Tensor<SrcEngine, SrcLayout>                                        const& src,
     Tensor<DstEngine, DstLayout>                                             & dst)
{
  using BulkTraits = Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>;
  BulkTraits const& traits = static_cast<BulkTraits const&>(atom);
  copy(traits, src, dst);
}
#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
//
// Decay TiledCopy to CopyAtom
//
// A TiledCopy used as a copy policy decays to its CopyAtom: the tiling information is
// ignored and the predicated copy is performed with the atom alone.
template <class CopyAtom, class TV, class Tiler,
          class PrdTensor,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(TiledCopy<CopyAtom, TV, Tiler> const& tiled_copy,
        PrdTensor                      const& pred,
        Tensor<SrcEngine, SrcLayout>   const& src,
        Tensor<DstEngine, DstLayout>        & dst)
{
  CopyAtom const& atom = static_cast<CopyAtom const&>(tiled_copy);
  copy_if(atom, pred, src, dst);
}
// A TiledCopy used as a copy policy decays to its CopyAtom: the tiling information is
// ignored and the copy is performed with the atom alone.
template <class CopyAtom, class TV, class Tiler,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(TiledCopy<CopyAtom, TV, Tiler> const& tiled_copy,
     Tensor<SrcEngine, SrcLayout>   const& src,
     Tensor<DstEngine, DstLayout>        & dst)
{
  CopyAtom const& atom = static_cast<CopyAtom const&>(tiled_copy);
  copy(atom, src, dst);
}
// Deleted overload: passing a ThrCopy as the policy of copy_if is a compile-time error.
// NOTE(review): presumably callers are expected to pass the TiledCopy or Copy_Atom
// instead of the per-thread slice -- confirm intended usage.
template <class TiledCopy, class ThrIdx,
class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(ThrCopy<TiledCopy, ThrIdx> const& thr_copy,
PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst) = delete;
// Deleted overload: passing a ThrCopy as the policy of copy is a compile-time error.
// NOTE(review): presumably callers are expected to pass the TiledCopy or Copy_Atom
// instead of the per-thread slice -- confirm intended usage.
template <class TiledCopy, class ThrIdx,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(ThrCopy<TiledCopy, ThrIdx> const& thr_copy,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst) = delete;
//
// Catch uncaught policies
//
// Fallback for copy_if with a policy type that no more specific overload accepts.
// Instantiating it triggers the static_assert below, so an unsupported policy is reported
// as "Unrecognized CopyPolicy." at compile time. None of the parameters are used.
template <class CopyPolicy,
class PredTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(CopyPolicy const& cpy,
PredTensor const& prd,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
// The condition depends on CopyPolicy, so it is only evaluated when this overload is
// actually instantiated (presumably dependent_false is always false -- defined elsewhere).
static_assert(dependent_false<CopyPolicy>, "Unrecognized CopyPolicy.");
}
// Fallback for copy with a policy type that no more specific overload accepts.
// Instantiating it triggers the static_assert below, so an unsupported policy is reported
// as "Unrecognized CopyPolicy." at compile time. None of the parameters are used.
template <class CopyPolicy,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(CopyPolicy const& cpy,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
// The condition depends on CopyPolicy, so it is only evaluated when this overload is
// actually instantiated (presumably dependent_false is always false -- defined elsewhere).
static_assert(dependent_false<CopyPolicy>, "Unrecognized CopyPolicy.");
}
//
// Accept mutable temporaries
//
// Accepts a temporary destination tensor for the policy-less predicated copy.
template <class PrdTensor,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(PrdTensor                    const& pred,
        Tensor<SrcEngine, SrcLayout> const& src,
        Tensor<DstEngine, DstLayout>     && dst)
{
  // A named rvalue reference is an lvalue expression, so this call cannot select this
  // `Tensor&&` overload again; it goes to the overload taking `Tensor&`.
  Tensor<DstEngine, DstLayout>& dst_lvalue = dst;
  copy_if(pred, src, dst_lvalue);
}
// Accepts a temporary destination tensor for the policy-taking predicated copy.
template <class CopyPolicy,
          class PrdTensor,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(CopyPolicy                   const& copy_policy,
        PrdTensor                    const& pred,
        Tensor<SrcEngine, SrcLayout> const& src,
        Tensor<DstEngine, DstLayout>     && dst)
{
  // A named rvalue reference is an lvalue expression, so this call cannot select this
  // `Tensor&&` overload again; it goes to an overload taking `Tensor&`.
  Tensor<DstEngine, DstLayout>& dst_lvalue = dst;
  copy_if(copy_policy, pred, src, dst_lvalue);
}
// Accepts a temporary destination tensor for the policy-less copy.
template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Tensor<SrcEngine, SrcLayout> const& src,
     Tensor<DstEngine, DstLayout>     && dst)
{
  // A named rvalue reference is an lvalue expression, so this call cannot select this
  // `Tensor&&` overload again; it goes to the overload taking `Tensor&`.
  Tensor<DstEngine, DstLayout>& dst_lvalue = dst;
  copy(src, dst_lvalue);
}
// Accepts a temporary destination tensor for the policy-taking copy.
template <class CopyPolicy,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(CopyPolicy                   const& copy_policy,
     Tensor<SrcEngine, SrcLayout> const& src,
     Tensor<DstEngine, DstLayout>     && dst)
{
  // A named rvalue reference is an lvalue expression, so this call cannot select this
  // `Tensor&&` overload again; it goes to an overload taking `Tensor&`.
  Tensor<DstEngine, DstLayout>& dst_lvalue = dst;
  copy(copy_policy, src, dst_lvalue);
}
// Accepts a temporary destination tensor for copy_aligned.
template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_aligned(Tensor<SrcEngine, SrcLayout> const& src,
             Tensor<DstEngine, DstLayout>     && dst)
{
  // A named rvalue reference is an lvalue expression, so this call cannot select this
  // `Tensor&&` overload again; it goes to the overload taking `Tensor&`.
  Tensor<DstEngine, DstLayout>& dst_lvalue = dst;
  copy_aligned(src, dst_lvalue);
}
} // end namespace cute