
Commit 0d1a706

[Performance] Make _to_consolidated compatible with compile

ghstack-source-id: f1f6ed823acf899a2c45b391063ca8b483147256
Pull Request resolved: #1041
Parent: 75b33c4
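The pattern this change optimizes is, in outline: pack a nested TensorDict into a single contiguous storage with consolidate(), then move it host-to-device in one transfer, optionally under torch.compile. A minimal sketch of that pattern, built only from the consolidate()/to() calls that appear in the diff below (sizes, the device fallback, and the warm-up loop are illustrative, not the library's prescribed usage):

import torch
from tensordict import TensorDict

td = TensorDict(
    {str(i): {str(j): torch.randn(16, 16) for j in range(4)} for i in range(4)},
    device="cpu",
)

device = "cuda" if torch.cuda.is_available() else "cpu"

def to_device(td, num_threads=None):
    # consolidate() packs every leaf into one contiguous storage, so the
    # host-to-device move can be a single copy rather than one per leaf;
    # pin_memory speeds up the CUDA transfer.
    return td.consolidate(pin_memory=(device == "cuda")).to(
        device, num_threads=num_threads
    )

to_device = torch.compile(to_device)  # the path this commit makes compile-friendly
for _ in range(3):  # warm-up, mirroring the benchmark below
    to_device(td)

The updated benchmark times three variants: no consolidation, consolidation ahead of time, and consolidation inside the timed region ("within").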

File tree

5 files changed: +429 -139 lines changed

benchmarks/common/h2d_test.py (+108 -23)
@@ -4,26 +4,39 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+from typing import Any
 
 import pytest
 import torch
 from packaging import version
 
-from tensordict import TensorDict
+from tensordict import tensorclass, TensorDict
+from tensordict.utils import logger as tensordict_logger
 
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 
-@pytest.fixture
-def td():
-    return TensorDict(
-        {
-            str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
-            for i in range(16)
-        },
-        batch_size=[16],
-        device="cpu",
-    )
+@tensorclass
+class NJT:
+    _values: torch.Tensor
+    _offsets: torch.Tensor
+    _lengths: torch.Tensor
+    njt_shape: Any = None
+
+    @classmethod
+    def from_njt(cls, njt_tensor):
+        return NJT(
+            _values=njt_tensor._values,
+            _offsets=njt_tensor._offsets,
+            _lengths=njt_tensor._lengths,
+            njt_shape=njt_tensor.size(0),
+        )
+
+
+@pytest.fixture(autouse=True, scope="function")
+def empty_compiler_cache():
+    torch._dynamo.reset_code_caches()
+    yield
 
 
 def _make_njt():
@@ -34,14 +47,27 @@ def _make_njt():
     )
 
 
-@pytest.fixture
-def njt_td():
+def _njt_td():
     return TensorDict(
         {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
        device="cpu",
     )
 
 
+@pytest.fixture
+def njt_td():
+    return _njt_td()
+
+
+@pytest.fixture
+def td():
+    njtd = _njt_td()
+    for k0, v0 in njtd.items():
+        for k1, v1 in v0.items():
+            njtd[k0, k1] = NJT.from_njt(v1)
+    return njtd
+
+
 @pytest.fixture
 def default_device():
     if torch.cuda.is_available():
@@ -52,22 +78,81 @@ def default_device():
         pytest.skip("CUDA/MPS is not available")
 
 
-@pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.parametrize(
+    "consolidated,compile_mode,num_threads",
+    [
+        [False, False, None],
+        [True, False, None],
+        ["within", False, None],
+        # [True, False, 4],
+        # [True, False, 16],
+        # [True, "default", None],
+    ],
+)
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
 )
 class TestTo:
-    def test_to(self, benchmark, consolidated, td, default_device):
-        if consolidated:
-            td = td.consolidate()
-        benchmark(lambda: td.to(default_device))
+    def test_to(
+        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
+    ):
+        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
+        pin_mem = default_device.type == "cuda"
+        if consolidated is True:
+            td = td.consolidate(pin_memory=pin_mem)
+
+        if consolidated == "within":
+
+            def to(td, num_threads):
+                return td.consolidate(pin_memory=pin_mem).to(
+                    default_device, num_threads=num_threads
+                )
+
+        else:
+
+            def to(td, num_threads):
+                return td.to(default_device, num_threads=num_threads)
+
+        if compile_mode:
+            to = torch.compile(to, mode=compile_mode)
+
+        for _ in range(3):
+            to(td, num_threads=num_threads)
+
+        benchmark(to, td, num_threads)
 
-    def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
-        if consolidated:
-            njt_td = njt_td.consolidate()
-        benchmark(lambda: njt_td.to(default_device))
+    def test_to_njt(
+        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
+    ):
+        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
+        pin_mem = default_device.type == "cuda"
+        if consolidated is True:
+            njt_td = njt_td.consolidate(pin_memory=pin_mem)
+
+        if consolidated == "within":
+
+            def to(td, num_threads):
+                return td.consolidate(pin_memory=pin_mem).to(
+                    default_device, num_threads=num_threads
+                )
+
+        else:
+
+            def to(td, num_threads):
+                return td.to(default_device, num_threads=num_threads)
+
+        if compile_mode:
+            to = torch.compile(to, mode=compile_mode)
+
+        for _ in range(3):
+            to(njt_td, num_threads=num_threads)
+
+        benchmark(to, njt_td, num_threads)
 
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
-    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
+    pytest.main(
+        [__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"]
+        + unknown
+    )
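As context for the NJT tensorclass introduced above: it decomposes a jagged nested tensor into plain dense tensors (_values/_offsets/_lengths) so the whole structure survives consolidate() and torch.compile. A hypothetical round-trip follows; _make_njt()'s body is not shown in this diff, so the construction below is an assumption based on torch.nested's public jagged API:

import torch

# Build a jagged nested tensor: 8 rows of varying length, feature dim 16.
lengths = torch.arange(1, 9)
offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)])
values = torch.randn(int(offsets[-1]), 16)
njt = torch.nested.nested_tensor_from_jagged(values, offsets, lengths=lengths)

# Wrap it with the NJT tensorclass defined in the diff above: every field
# of the wrapper is a regular tensor (plus the stored batch size).
wrapped = NJT.from_njt(njt)
assert wrapped.njt_shape == njt.size(0)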

benchmarks/compile/compile_td_test.py (+6 -0)

@@ -23,6 +23,12 @@ class MyTensorClass:
     f: torch.Tensor
 
 
+@pytest.fixture(autouse=True, scope="function")
+def empty_compiler_cache():
+    torch._dynamo.reset_code_caches()
+    yield
+
+
 # Functions
 def add_one(td):
     return td + 1

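This is the same autouse fixture added to h2d_test.py: each benchmark function starts from a cold torch._dynamo code cache, so no test is flattered by compiled code left over from a previous one. A small sketch of the effect it guards against (assuming tests share a single process, as under pytest-benchmark):

import torch

def add_one_plain(x):
    return x + 1

fn = torch.compile(add_one_plain)
fn(torch.ones(3))                  # first call compiles and caches guarded code
torch._dynamo.reset_code_caches()  # what the fixture runs before each test
fn(torch.ones(3))                  # runs from a cold cache again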