
Commit 1855b14

Mikhail Zolotukhin authored and pytorchmergebot committed
[TensorExpr] Delete DimArg class. (pytorch#72390)
Summary: Pull Request resolved: pytorch#72390

This class didn't add much value and only caused more boilerplate code. This change removes the class and updates all of its use sites to take `ExprHandle` instead. A side effect of this change is different names for loop variables, which caused massive mechanical changes in our tests.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D34030296

Pulled By: ZolotukhinM

fbshipit-source-id: 2ba4e313506a43ab129a10d99e72b638b7d40108
(cherry picked from commit c2ec46a)
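To make the API change concrete, here is a minimal before/after sketch distilled from the call sites below; the names `t`, `a`, `M`, and `N` are illustrative placeholders, not taken from any one file:

// Before: each dimension was a DimArg, pairing an extent with an explicit
// loop-variable name.
Tensor t = Compute(
    "t",
    {{M, "m"}, {N, "n"}},
    [&](const VarHandle& m, const VarHandle& n) { return a.load(m, n); });

// After: dimensions are plain ExprHandles; loop-variable names are generated
// automatically, hence the mechanical renames in the tests.
Tensor t = Compute(
    "t",
    {M, N},
    [&](const VarHandle& m, const VarHandle& n) { return a.load(m, n); });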
1 parent 9123e9b · commit 1855b14

39 files changed: +948 −1204 lines

benchmarks/cpp/tensorexpr/bench_batchnorm.cpp (+4 −8)

@@ -82,10 +82,8 @@ BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) {
   VarHandle eps("eps", kFloat);

   using axis = const VarHandle&;
-  Tensor output = Compute(
-      "output",
-      {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
-      [&](axis n, axis c, axis h, axis w) {
+  Tensor output =
+      Compute("output", {N_, C_, H_, W_}, [&](axis n, axis c, axis h, axis w) {
         // Compute affine terms.
         auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
         auto weight_v = weight.load(c);
@@ -143,10 +141,8 @@ BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) {
   VarHandle eps("eps", kFloat);

   using axis = const VarHandle&;
-  Tensor output = Compute(
-      "output",
-      {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}},
-      [&](axis n, axis c, axis h, axis w) {
+  Tensor output =
+      Compute("output", {N_, C_, H_, W_}, [&](axis n, axis c, axis h, axis w) {
         // Compute affine terms.
         auto inv_var = FloatImm::make(1.0f) / sqrt(var.load(c) + eps);
         auto weight_v = weight.load(c);

benchmarks/cpp/tensorexpr/bench_compile.cpp (+29 −40)

@@ -12,26 +12,21 @@ static void BM_CompileSwish(benchmark::State& state) {
   constexpr int N = 512;
   te::VarHandle n("n", te::kInt);
   te::BufHandle A("A", {N}, te::kFloat);
-  te::Tensor relu =
-      te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return te::Max::make(A.load(i), 0.f, false);
-      });
-  te::Tensor min6 =
-      te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return te::Min::make(relu.load(i), 6.f, false);
-      });
-  te::Tensor plus3 =
-      te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return min6.load(i) + 3.f;
-      });
-  te::Tensor times =
-      te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return A.load(i) * plus3.load(i);
-      });
-  te::Tensor sixth =
-      te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return times.load(i) * 1.f / 6.f;
-      });
+  te::Tensor relu = te::Compute("relu", {n}, [&](const te::VarHandle& i) {
+    return te::Max::make(A.load(i), 0.f, false);
+  });
+  te::Tensor min6 = te::Compute("min6", {n}, [&](const te::VarHandle& i) {
+    return te::Min::make(relu.load(i), 6.f, false);
+  });
+  te::Tensor plus3 = te::Compute("plus3", {n}, [&](const te::VarHandle& i) {
+    return min6.load(i) + 3.f;
+  });
+  te::Tensor times = te::Compute("times", {n}, [&](const te::VarHandle& i) {
+    return A.load(i) * plus3.load(i);
+  });
+  te::Tensor sixth = te::Compute("sixth", {n}, [&](const te::VarHandle& i) {
+    return times.load(i) * 1.f / 6.f;
+  });
   te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
   for (auto tensor : {relu, min6, plus3, times}) {
     nest.computeInline(tensor.buf());
@@ -46,26 +41,20 @@ static void BM_CompileSwishLLVMOnly(benchmark::State& state) {
   constexpr int N = 512;
   te::VarHandle n("n", te::kInt);
   te::BufHandle A("A", {N}, te::kFloat);
-  te::Tensor relu =
-      te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return te::Max::make(A.load(i), 0.f, false);
-      });
-  te::Tensor min6 =
-      te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return te::Min::make(relu.load(i), 6.f, false);
-      });
-  te::Tensor plus3 =
-      te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return min6.load(i) + 3.f;
-      });
-  te::Tensor times =
-      te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return A.load(i) * plus3.load(i);
-      });
-  te::Tensor sixth =
-      te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
-        return times.load(i) * 1.f / 6.f;
-      });
+  te::Tensor relu = te::Compute("relu", {n}, [&](const te::VarHandle& i) {
+    return te::Max::make(A.load(i), 0.f, false);
+  });
+  te::Tensor min6 = te::Compute("min6", {n}, [&](const te::VarHandle& i) {
+    return te::Min::make(relu.load(i), 6.f, false);
+  });
+  te::Tensor plus3 = te::Compute(
+      "plus3", {n}, [&](const te::VarHandle& i) { return min6.load(i) + 3.f; });
+  te::Tensor times = te::Compute("times", {n}, [&](const te::VarHandle& i) {
+    return A.load(i) * plus3.load(i);
+  });
+  te::Tensor sixth = te::Compute("sixth", {n}, [&](const te::VarHandle& i) {
+    return times.load(i) * 1.f / 6.f;
+  });
   te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
   for (auto tensor : {relu, min6, plus3, times}) {
     nest.computeInline(tensor.buf());

benchmarks/cpp/tensorexpr/bench_concat.cpp (+1 −1)

@@ -61,7 +61,7 @@ class ConcatBench : public benchmark::Fixture {

   Tensor output = Compute(
       "aten_cat",
-      {{output_size_[0], "M"}, {output_size_[1], "N"}},
+      {output_size_[0], output_size_[1]},
       [&](const VarHandle& m, const VarHandle& n) {
         int d = 0;
         std::vector<int> cumulative_concat_dim_sizes(num_inputs);

benchmarks/cpp/tensorexpr/bench_gemm.cpp (+10 −10)

@@ -44,12 +44,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) {
   te::BufHandle BP("B", {K, N}, te::kFloat);
   te::Tensor CT = te::Reduce(
       "gemm",
-      {{M, "M"}, {N, "N"}},
+      {M, N},
       te::Sum(),
       [&](const te::ExprHandle& m,
           const te::ExprHandle& n,
           const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
-      {{K, "K"}});
+      {K});
   te::LoopNest loop({CT});
   loop.prepareForCodegen();
   te::StmtPtr s = loop.root_stmt();
@@ -66,12 +66,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) {
   te::BufHandle BP("B", {K, N}, te::kFloat);
   te::Tensor CT = te::Reduce(
       "gemm",
-      {{M, "M"}, {N, "N"}},
+      {M, N},
       te::Sum(),
       [&](const te::ExprHandle& m,
           const te::ExprHandle& n,
           const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
-      {{K, "K"}});
+      {K});
   te::LoopNest loop({CT});

   {
@@ -124,12 +124,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) {
   te::BufHandle BP("B", {K, N}, te::kFloat);
   te::Tensor CT = te::Reduce(
       "gemm",
-      {{M, "M"}, {N, "N"}},
+      {M, N},
       te::Sum(),
       [&](const te::ExprHandle& m,
           const te::ExprHandle& n,
           const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
-      {{K, "K"}});
+      {K});
   te::LoopNest loop({CT});

   {
@@ -182,12 +182,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) {
   te::BufHandle BP("B", {K, N}, te::kFloat);
   te::Tensor CT = te::Reduce(
       "gemm",
-      {{M, "M"}, {N, "N"}},
+      {M, N},
       te::Sum(),
       [&](const te::ExprHandle& m,
           const te::ExprHandle& n,
           const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
-      {{K, "K"}});
+      {K});
   te::LoopNest loop({CT});

   {
@@ -248,12 +248,12 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) {
   te::BufHandle BP("B", {K, N}, te::kFloat);
   te::Tensor CT = te::Reduce(
       "gemm",
-      {{M, "M"}, {N, "N"}},
+      {M, N},
       te::Sum(),
       [&](const te::ExprHandle& m,
           const te::ExprHandle& n,
           const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
-      {{K, "K"}});
+      {K});
   te::LoopNest loop({CT});

   {
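The Reduce changes above follow the same pattern as Compute: both the output dimensions and the reduction dimensions become plain `ExprHandle` lists. A sketch of the new call shape, mirroring the gemm call sites above:

// Output dims come first ({M, N}, formerly {{M, "M"}, {N, "N"}}) and the
// reduction dims come last ({K}, formerly {{K, "K"}}); the body lambda, which
// receives one handle per output dim plus one per reduction dim, is unchanged.
te::Tensor CT = te::Reduce(
    "gemm",
    {M, N},
    te::Sum(),
    [&](const te::ExprHandle& m,
        const te::ExprHandle& n,
        const te::ExprHandle& k) { return AP.load(m, k) * BP.load(k, n); },
    {K});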

benchmarks/cpp/tensorexpr/bench_parallel.cpp (+1 −1)

@@ -38,7 +38,7 @@ class ParallelAdd : public benchmark::Fixture {
 BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) {
   BufHandle a_buf("a", {M}, kFloat);
   BufHandle b_buf("b", {M}, kFloat);
-  Tensor c_tensor = Compute("c", {{M, "m"}}, [&](const VarHandle& m) {
+  Tensor c_tensor = Compute("c", {M}, [&](const VarHandle& m) {
     return a_buf.load(m) + b_buf.load(m);
   });
   LoopNest loop_nest({c_tensor});

benchmarks/cpp/tensorexpr/bench_reduce.cpp (+7 −7)

@@ -235,12 +235,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeNaive)(benchmark::State& state) {
   te::BufHandle AP("A", {M}, te::kFloat);
   te::Tensor BT = te::Reduce(
       "reduce_full",
-      {{1, "N"}},
+      {1},
       te::Sum(),
       [&](const te::ExprHandle& n, const te::ExprHandle& m) {
         return AP.load(m);
       },
-      {{M, "M"}});
+      {M});

   te::LoopNest loop({BT});
   loop.prepareForCodegen();
@@ -266,12 +266,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitTail)(benchmark::State& state) {
   te::BufHandle AP("A", {M}, te::kFloat);
   te::Tensor BT = te::Reduce(
       "reduce_full",
-      {{1, "N"}},
+      {1},
       te::Sum(),
       [&](const te::ExprHandle& n, const te::ExprHandle& m) {
         return AP.load(m);
       },
-      {{M, "M"}});
+      {M});

   te::LoopNest loop({BT});
   const int kChunkSize = 8;
@@ -305,12 +305,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitMask)(benchmark::State& state) {
   te::BufHandle AP("A", {M}, te::kFloat);
   te::Tensor BT = te::Reduce(
       "reduce_full",
-      {{1, "N"}},
+      {1},
       te::Sum(),
       [&](const te::ExprHandle& n, const te::ExprHandle& m) {
         return AP.load(m);
       },
-      {{M, "M"}});
+      {M});

   te::LoopNest loop({BT});
   const int kChunkSize = 8;
@@ -349,7 +349,7 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) {
       {},
       te::Sum(),
       [&](const te::ExprHandle& m) { return AP.load(m); },
-      {{M, "M"}});
+      {M});

   te::LoopNest loop({BT});
   te::BufPtr rfac_buf;
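For full (scalar) reductions such as TeRfactorV1 above, the output-dims argument is simply an empty list. A minimal sketch of that shape, with names taken from the hunks above:

// Reduce an M-element buffer AP down to a scalar: no output dims, one
// reduction dim of extent M (formerly the DimArg list {{M, "M"}}).
te::Tensor BT = te::Reduce(
    "reduce_full",
    {},       // empty output dims: the result is a scalar
    te::Sum(),
    [&](const te::ExprHandle& m) { return AP.load(m); },
    {M});     // reduce over m in [0, M)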

benchmarks/cpp/tensorexpr/bench_signed_log1p.cpp (+6 −6)

@@ -46,21 +46,21 @@ class SignedLog1pBench : public benchmark::Fixture {
       "input", {input_size_int_[0], input_size_int_[1]}, kFloat);
   Tensor abs_result = Compute(
       "aten_abs",
-      {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
+      {input_size_int_[0], input_size_int_[1]},
       [&](const VarHandle& m, const VarHandle& n) {
         return abs(input_ph.load(m, n));
       });
   Tensor log1p_result = Compute(
       "aten_log1p",
-      {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
+      {input_size_int_[0], input_size_int_[1]},
       [&](const VarHandle& m, const VarHandle& n) {
         return log1p(abs_result.load(m, n));
       });
   Tensor sign_result =
       computeSign({input_ph}, {input_size_int_[0], input_size_int_[1]});
   Tensor output = Compute(
       "aten_mul",
-      {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
+      {input_size_int_[0], input_size_int_[1]},
       [&](const VarHandle& m, const VarHandle& n) {
         return sign_result.load(m, n) * log1p_result.load(m, n);
       });
@@ -94,21 +94,21 @@ class SignedLog1pBench : public benchmark::Fixture {
       "input", {input_size_int_[0], input_size_int_[1]}, kFloat);
   Tensor abs_result = Compute(
       "aten_abs",
-      {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
+      {input_size_int_[0], input_size_int_[1]},
       [&](const VarHandle& m, const VarHandle& n) {
         return abs(input_ph.load(m, n));
       });
   Tensor log_vml_result = Compute(
       "aten_log1p",
-      {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
+      {input_size_int_[0], input_size_int_[1]},
       [&](const VarHandle& m, const VarHandle& n) {
         return log_vml(abs_result.load(m, n) + ExprHandle(1));
       });
   Tensor sign_result =
       computeSign({input_ph}, {input_size_int_[0], input_size_int_[1]});
   Tensor output = Compute(
       "aten_mul",
-      {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}},
+      {input_size_int_[0], input_size_int_[1]},
       [&](const VarHandle& m, const VarHandle& n) {
         return sign_result.load(m, n) * log_vml_result.load(m, n);
       });
