4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import argparse
7
+ from typing import Any
7
8
8
9
import pytest
9
10
import torch
10
11
from packaging import version
11
12
12
- from tensordict import TensorDict
13
+ from tensordict import tensorclass , TensorDict
14
+ from tensordict .utils import logger as tensordict_logger
13
15
14
16
# Installed torch release with any pre-release/dev/local suffix stripped
# (base_version), so it compares cleanly against plain version literals —
# used below to skip benchmarks on torch < 2.5.
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
15
17
16
18
17
- @pytest .fixture
18
- def td ():
19
- return TensorDict (
20
- {
21
- str (i ): {str (j ): torch .randn (16 , 16 , device = "cpu" ) for j in range (16 )}
22
- for i in range (16 )
23
- },
24
- batch_size = [16 ],
25
- device = "cpu" ,
26
- )
19
@tensorclass
class NJT:
    """Tensorclass wrapper around the raw buffers of a torch nested-jagged tensor.

    Stores the three internal buffers (``_values``, ``_offsets``, ``_lengths``)
    plus the leading size of the source tensor, so a jagged tensor can be
    carried as a regular entry inside a TensorDict.
    """

    _values: torch.Tensor
    _offsets: torch.Tensor
    # NOTE(review): _lengths may be None for contiguous NJTs — confirm against
    # how _make_njt builds its tensors.
    _lengths: torch.Tensor
    # size(0) of the source NJT, captured in from_njt.
    njt_shape: Any = None

    @classmethod
    def from_njt(cls, njt_tensor):
        """Build an NJT tensorclass from a nested-jagged *njt_tensor* by
        capturing its internal buffers (no copy of the underlying storage)."""
        return NJT(
            _values=njt_tensor._values,
            _offsets=njt_tensor._offsets,
            _lengths=njt_tensor._lengths,
            njt_shape=njt_tensor.size(0),
        )
34
+
35
+
36
@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
    # Drop dynamo's compiled-code caches before every test so each benchmark
    # starts from a cold compiler state.
    torch._dynamo.reset_code_caches()
    yield None
27
40
28
41
29
42
def _make_njt ():
@@ -34,14 +47,27 @@ def _make_njt():
34
47
)
35
48
36
49
37
def _njt_td():
    """Return a CPU TensorDict holding a 32x32 grid of nested-jagged tensors."""
    leaves = {
        str(outer): {str(inner): _make_njt() for inner in range(32)}
        for outer in range(32)
    }
    return TensorDict(leaves, device="cpu")
43
55
44
56
57
@pytest.fixture
def njt_td():
    """Fresh TensorDict of raw nested-jagged tensors for each test."""
    built = _njt_td()
    return built
60
+
61
+
62
@pytest.fixture
def td():
    """Same layout as ``njt_td`` but with every jagged leaf re-wrapped as an
    ``NJT`` tensorclass (plain tensor buffers instead of NJT tensors)."""
    wrapped = _njt_td()
    for outer_key, inner_td in wrapped.items():
        for inner_key, leaf in inner_td.items():
            wrapped[outer_key, inner_key] = NJT.from_njt(leaf)
    return wrapped
69
+
70
+
45
71
@pytest .fixture
46
72
def default_device ():
47
73
if torch .cuda .is_available ():
@@ -52,22 +78,81 @@ def default_device():
52
78
pytest .skip ("CUDA/MPS is not available" )
53
79
54
80
55
@pytest.mark.parametrize(
    "consolidated,compile_mode,num_threads",
    [
        [False, False, None],
        [True, False, None],
        ["within", False, None],
        # [True, False, 4],
        # [True, False, 16],
        # [True, "default", None],
    ],
)
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestTo:
    """Benchmarks for ``TensorDict.to(device)``.

    ``consolidated`` selects the consolidation strategy:
      * ``False``    — plain ``.to(device)``;
      * ``True``     — consolidate (optionally pinning memory) once, outside
                       the timed region;
      * ``"within"`` — consolidate inside the timed call.
    ``compile_mode`` optionally wraps the transfer in ``torch.compile``;
    ``num_threads`` is forwarded to ``.to``.
    """

    def _run_to_benchmark(
        self,
        benchmark,
        consolidated,
        data,
        default_device,
        compile_mode,
        num_threads,
        label,
    ):
        # Shared driver for test_to / test_to_njt — their bodies were
        # line-for-line duplicates differing only in the fixture and the
        # log label, so the logic lives here exactly once.
        tensordict_logger.info(f"{label} size {data.bytes() / 1024 / 1024:.2f} Mb")
        # Pinned host memory is only useful when copying to a CUDA device.
        pin_mem = default_device.type == "cuda"
        if consolidated is True:
            data = data.consolidate(pin_memory=pin_mem)

        if consolidated == "within":
            # Consolidation happens inside the timed call.
            def to(td, num_threads):
                return td.consolidate(pin_memory=pin_mem).to(
                    default_device, num_threads=num_threads
                )

        else:

            def to(td, num_threads):
                return td.to(default_device, num_threads=num_threads)

        if compile_mode:
            to = torch.compile(to, mode=compile_mode)

        # Warm-up iterations so compilation and cache population happen
        # outside the measured region.
        for _ in range(3):
            to(data, num_threads=num_threads)

        benchmark(to, data, num_threads)

    def test_to(
        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
    ):
        self._run_to_benchmark(
            benchmark, consolidated, td, default_device, compile_mode, num_threads, "td"
        )

    def test_to_njt(
        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
    ):
        self._run_to_benchmark(
            benchmark,
            consolidated,
            njt_td,
            default_device,
            compile_mode,
            num_threads,
            "njtd",
        )
69
151
70
152
71
153
if __name__ == "__main__":
    # Forward any unrecognized CLI flags straight to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [
        __file__,
        "--capture",
        "no",
        "--exitfirst",
        "--benchmark-group-by",
        "func",
    ]
    pytest_argv.extend(unknown)
    pytest.main(pytest_argv)
0 commit comments