@@ -12,26 +12,21 @@ static void BM_CompileSwish(benchmark::State& state) {
12
12
constexpr int N = 512 ;
13
13
te::VarHandle n (" n" , te::kInt );
14
14
te::BufHandle A (" A" , {N}, te::kFloat );
15
- te::Tensor relu =
16
- te::Compute (" relu" , {{n, " n" }}, [&](const te::VarHandle& i) {
17
- return te::Max::make (A.load (i), 0 .f , false );
18
- });
19
- te::Tensor min6 =
20
- te::Compute (" min6" , {{n, " n" }}, [&](const te::VarHandle& i) {
21
- return te::Min::make (relu.load (i), 6 .f , false );
22
- });
23
- te::Tensor plus3 =
24
- te::Compute (" plus3" , {{n, " n" }}, [&](const te::VarHandle& i) {
25
- return min6.load (i) + 3 .f ;
26
- });
27
- te::Tensor times =
28
- te::Compute (" times" , {{n, " n" }}, [&](const te::VarHandle& i) {
29
- return A.load (i) * plus3.load (i);
30
- });
31
- te::Tensor sixth =
32
- te::Compute (" sixth" , {{n, " n" }}, [&](const te::VarHandle& i) {
33
- return times.load (i) * 1 .f / 6 .f ;
34
- });
15
+ te::Tensor relu = te::Compute (" relu" , {n}, [&](const te::VarHandle& i) {
16
+ return te::Max::make (A.load (i), 0 .f , false );
17
+ });
18
+ te::Tensor min6 = te::Compute (" min6" , {n}, [&](const te::VarHandle& i) {
19
+ return te::Min::make (relu.load (i), 6 .f , false );
20
+ });
21
+ te::Tensor plus3 = te::Compute (" plus3" , {n}, [&](const te::VarHandle& i) {
22
+ return min6.load (i) + 3 .f ;
23
+ });
24
+ te::Tensor times = te::Compute (" times" , {n}, [&](const te::VarHandle& i) {
25
+ return A.load (i) * plus3.load (i);
26
+ });
27
+ te::Tensor sixth = te::Compute (" sixth" , {n}, [&](const te::VarHandle& i) {
28
+ return times.load (i) * 1 .f / 6 .f ;
29
+ });
35
30
te::LoopNest nest ({sixth}, {relu, min6, plus3, times, sixth});
36
31
for (auto tensor : {relu, min6, plus3, times}) {
37
32
nest.computeInline (tensor.buf ());
@@ -46,26 +41,20 @@ static void BM_CompileSwishLLVMOnly(benchmark::State& state) {
46
41
constexpr int N = 512 ;
47
42
te::VarHandle n (" n" , te::kInt );
48
43
te::BufHandle A (" A" , {N}, te::kFloat );
49
- te::Tensor relu =
50
- te::Compute (" relu" , {{n, " n" }}, [&](const te::VarHandle& i) {
51
- return te::Max::make (A.load (i), 0 .f , false );
52
- });
53
- te::Tensor min6 =
54
- te::Compute (" min6" , {{n, " n" }}, [&](const te::VarHandle& i) {
55
- return te::Min::make (relu.load (i), 6 .f , false );
56
- });
57
- te::Tensor plus3 =
58
- te::Compute (" plus3" , {{n, " n" }}, [&](const te::VarHandle& i) {
59
- return min6.load (i) + 3 .f ;
60
- });
61
- te::Tensor times =
62
- te::Compute (" times" , {{n, " n" }}, [&](const te::VarHandle& i) {
63
- return A.load (i) * plus3.load (i);
64
- });
65
- te::Tensor sixth =
66
- te::Compute (" sixth" , {{n, " n" }}, [&](const te::VarHandle& i) {
67
- return times.load (i) * 1 .f / 6 .f ;
68
- });
44
+ te::Tensor relu = te::Compute (" relu" , {n}, [&](const te::VarHandle& i) {
45
+ return te::Max::make (A.load (i), 0 .f , false );
46
+ });
47
+ te::Tensor min6 = te::Compute (" min6" , {n}, [&](const te::VarHandle& i) {
48
+ return te::Min::make (relu.load (i), 6 .f , false );
49
+ });
50
+ te::Tensor plus3 = te::Compute (
51
+ " plus3" , {n}, [&](const te::VarHandle& i) { return min6.load (i) + 3 .f ; });
52
+ te::Tensor times = te::Compute (" times" , {n}, [&](const te::VarHandle& i) {
53
+ return A.load (i) * plus3.load (i);
54
+ });
55
+ te::Tensor sixth = te::Compute (" sixth" , {n}, [&](const te::VarHandle& i) {
56
+ return times.load (i) * 1 .f / 6 .f ;
57
+ });
69
58
te::LoopNest nest ({sixth}, {relu, min6, plus3, times, sixth});
70
59
for (auto tensor : {relu, min6, plus3, times}) {
71
60
nest.computeInline (tensor.buf ());
0 commit comments