@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
3521
3521
3522
3522
flat_size = []
3523
3523
start = 0
3524
+ sorting_index = 0
3524
3525
3525
3526
def add_single_value (value , key , metadata_dict , dtype , shape , flat_size ):
3526
- nonlocal start
3527
+ nonlocal start , sorting_index
3527
3528
n = value .element_size () * value .numel ()
3528
3529
if need_padding :
3529
3530
pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
3541
3542
start ,
3542
3543
stop ,
3543
3544
pad ,
3545
+ flat_size [- 1 ],
3546
+ sorting_index ,
3544
3547
)
3548
+ sorting_index = sorting_index + 1
3545
3549
start = stop
3546
3550
3547
3551
def assign (
@@ -10395,6 +10399,7 @@ def to(self, *args, **kwargs) -> T:
10395
10399
pin_memory = non_blocking_pin ,
10396
10400
num_threads = num_threads ,
10397
10401
non_blocking = non_blocking ,
10402
+ compilable = is_dynamo_compiling (),
10398
10403
)
10399
10404
10400
10405
if non_blocking is None :
@@ -10452,14 +10457,42 @@ def to_pinmem(tensor, _to=to):
10452
10457
self ._sync_all ()
10453
10458
return result
10454
10459
10455
- def _to_consolidated (self , * , device , pin_memory , num_threads , non_blocking ):
10460
+ def _to_consolidated (self , * , device , pin_memory , num_threads , non_blocking , compilable ):
10456
10461
if num_threads is None :
10457
10462
# unspecified num_threads should mean 0
10458
10463
num_threads = 0
10464
+
10459
10465
storage = self ._consolidated ["storage" ]
10460
- if pin_memory :
10461
- storage = storage .pin_memory ()
10462
- storage_cast = storage .to (device , non_blocking = True )
10466
+
10467
+ @torch .compiler .disable ()
10468
+ def to (storage ):
10469
+ if pin_memory :
10470
+ storage = storage .pin_memory ()
10471
+ storage_cast = storage .to (device , non_blocking = True )
10472
+ return storage_cast
10473
+ storage_cast = to (storage )
10474
+
10475
+ if compilable :
10476
+ result = self ._to_consolidated_compile (device = device , num_threads = num_threads , storage_cast = storage_cast )
10477
+ else :
10478
+ result = self ._to_consolidated_eager (device = device , num_threads = num_threads , storage_cast = storage_cast )
10479
+
10480
+ if non_blocking in (False , None ):
10481
+ if device .type == "cuda" and non_blocking is False :
10482
+ # sending to CUDA force sync
10483
+ cuda_device = device
10484
+ elif storage .device .type == "cuda" :
10485
+ # sending from cuda: need sync unless intentionally not asked for
10486
+ cuda_device = storage .device .type
10487
+ else :
10488
+ cuda_device = None
10489
+ if cuda_device is not None :
10490
+ torch .cuda .current_stream (cuda_device ).synchronize ()
10491
+
10492
+ return result
10493
+
10494
+ def _to_consolidated_eager (self , * , device , num_threads , storage_cast ):
10495
+
10463
10496
untyped_storage = storage_cast .untyped_storage ()
10464
10497
10465
10498
def set_ (x ):
@@ -10528,20 +10561,111 @@ def copy_dict(d):
10528
10561
}
10529
10562
10530
10563
result ._consolidated ["metadata" ] = copy_dict (self ._consolidated ["metadata" ])
10531
- if non_blocking in (False , None ):
10532
- if device .type == "cuda" and non_blocking is False :
10533
- # sending to CUDA force sync
10534
- cuda_device = device
10535
- elif storage .device .type == "cuda" :
10536
- # sending from cuda: need sync unless intentionally not asked for
10537
- cuda_device = storage .device .type
10538
- else :
10539
- cuda_device = None
10540
- if cuda_device is not None :
10541
- torch .cuda .current_stream (cuda_device ).synchronize ()
10542
-
10543
10564
return result
10544
10565
10566
+ def _to_consolidated_compile (self , * , device , num_threads , storage_cast ):
10567
+
10568
+ def get_tensors_length (metadata , lengths = None , pos = None , keys = None , prefix = ()):
10569
+ root = False
10570
+ if lengths is None :
10571
+ lengths = []
10572
+ pos = []
10573
+ keys = []
10574
+ root = True
10575
+ for k , v in metadata ["leaves" ].items ():
10576
+ lengths .append (v [- 2 ])
10577
+ pos .append (v [- 1 ])
10578
+ keys .append (prefix + (k ,))
10579
+ for k , d in metadata .items ():
10580
+ if "leaves" in d :
10581
+ get_tensors_length (d , lengths = lengths , pos = pos , keys = keys , prefix = prefix + (k ,))
10582
+ if root :
10583
+ # l = torch.empty(len(lengths), dtype=torch.long)
10584
+ # l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
10585
+ out0 = [None , ] * len (pos )
10586
+ out1 = [None , ] * len (pos )
10587
+ for p , l , k in zip (pos , lengths , keys ):
10588
+ out0 [p ] = k
10589
+ out1 [p ] = l
10590
+ return out0 , out1
10591
+
10592
+ def split_storage (consolidated ):
10593
+ keys , splits = get_tensors_length (consolidated ["metadata" ])
10594
+ return dict (zip (keys , consolidated ["storage" ].split (splits )))
10595
+
10596
+ if num_threads is None :
10597
+ # unspecified num_threads should mean 0
10598
+ num_threads = 0
10599
+
10600
+ _consolidated = {"storage" : storage_cast }
10601
+ if "metadata" in self ._consolidated :
10602
+ # faster than deepcopy
10603
+ def copy_dict (d ):
10604
+ return {
10605
+ k : v if not isinstance (v , dict ) else copy_dict (v )
10606
+ for k , v in d .items ()
10607
+ }
10608
+
10609
+ _consolidated ["metadata" ] = copy_dict (self ._consolidated ["metadata" ])
10610
+
10611
+ slice_map = split_storage (_consolidated )
10612
+
10613
+ def set_ (name , x ):
10614
+ if not isinstance (name , tuple ):
10615
+ name = (name ,)
10616
+ if x .is_nested :
10617
+ from torch ._subclasses .fake_tensor import FakeTensor
10618
+ from torch ._subclasses .functional_tensor import FunctionalTensor
10619
+ from torch .nested ._internal .nested_tensor import (
10620
+ _tensor_symint_registry ,
10621
+ NestedTensor ,
10622
+ )
10623
+ from torch .nested ._internal .ops import extract_kwargs
10624
+
10625
+ if x .layout != torch .jagged :
10626
+ raise RuntimeError (
10627
+ "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
10628
+ "Please raise an issue on GitHub."
10629
+ )
10630
+ kwargs = extract_kwargs (x )
10631
+ values = x ._values
10632
+ lengths = x ._lengths
10633
+ offsets = x ._offsets
10634
+ storage_offsets = slice_map [(* name [:- 1 ], "<NJT_OFFSETS>" + name [- 1 ],)]
10635
+ kwargs ["offsets" ] = storage_offsets .view (offsets .dtype ).view (offsets .shape )
10636
+ if lengths is not None :
10637
+ storage_lengths = slice_map [(* name [:- 1 ], "<NJT_LENGTHS>" + name [- 1 ],)]
10638
+ kwargs ["lengths" ] = storage_lengths .view (lengths .dtype ).view (lengths .shape )
10639
+ ragged_source = lengths
10640
+ else :
10641
+ ragged_source = offsets
10642
+ new_thing = kwargs .get ("lengths" , kwargs .get ("offsets" ))
10643
+ if isinstance (new_thing , (FakeTensor , FunctionalTensor )):
10644
+ from torch ._subclasses .functional_tensor import (
10645
+ mb_unwrap_functional_tensor ,
10646
+ )
10647
+
10648
+ # Temporary hack until we have the union find
10649
+ tgt = mb_unwrap_functional_tensor (new_thing )
10650
+ src = mb_unwrap_functional_tensor (ragged_source )
10651
+ tgt .nested_int_memo = src .nested_int_memo
10652
+ else :
10653
+ _tensor_symint_registry [new_thing ] = _tensor_symint_registry [
10654
+ ragged_source
10655
+ ]
10656
+
10657
+ storage_values = slice_map [(* name [:- 1 ], "<NJT_VALUES>" + name [- 1 ],)]
10658
+ return NestedTensor (
10659
+ storage_values .view (values .dtype ).view (values .shape ),
10660
+ ** kwargs ,
10661
+ )
10662
+ return slice_map [name ].view (x .dtype ).view (x .shape )
10663
+
10664
+ result = self ._fast_apply (
10665
+ set_ , device = torch .device (device ), num_threads = num_threads , named = True , nested_keys = True ,
10666
+ )
10667
+ result ._consolidated = _consolidated
10668
+ return result
10545
10669
def _sync_all (self ):
10546
10670
if _has_cuda :
10547
10671
# TODO: dynamo doesn't like torch.cuda.is_initialized
0 commit comments