3
3
4
4
#if AT_USE_JITERATOR()
5
5
6
- #include < c10/util/variant.h>
7
6
#include < ATen/native/TensorIterator.h>
8
7
#include < ATen/cuda/detail/OffsetCalculator.cuh>
9
8
#include < ATen/native/cuda/jit_utils.h>
10
9
#include < ATen/native/cuda/MemoryAccess.cuh>
11
10
#include < ATen/native/cuda/JitLoops.cuh>
12
11
13
12
#include < string>
13
+ #include < variant>
14
14
#include < vector>
15
15
16
16
namespace at ::native {
@@ -93,7 +93,7 @@ static std::unique_ptr<OffsetCalculator<N>> make_unique_offset_calculator(
93
93
template <bool IS_INPUT>
94
94
struct OffsetCalculatorVariant {
95
95
#define DEFINE_CASE (index ) std::unique_ptr<OffsetCalculator<index>>
96
- using OffsetCalculatorTypes = c10 ::variant<
96
+ using OffsetCalculatorTypes = std ::variant<
97
97
AT_FOR_8_CASES_WITH_COMMA (DEFINE_CASE)
98
98
>;
99
99
#undef DEFINE_CASE
@@ -113,7 +113,7 @@ struct OffsetCalculatorVariant {
113
113
}
114
114
115
115
void * data_ptr () {
116
- return c10 ::visit ([](auto & v){ return static_cast <void *>(v.get ()); }, v);
116
+ return std ::visit ([](auto & v){ return static_cast <void *>(v.get ()); }, v);
117
117
}
118
118
119
119
private:
@@ -123,7 +123,7 @@ struct OffsetCalculatorVariant {
123
123
struct ArrayVariant {
124
124
// works for up to 8 input + 8 outputs
125
125
#define DEFINE_CASE (index ) at::detail::Array<char *, index>, at::detail::Array<char *, index+8 >
126
- using ArrayTypes = c10 ::variant<
126
+ using ArrayTypes = std ::variant<
127
127
AT_FOR_8_CASES_WITH_COMMA (DEFINE_CASE)
128
128
>;
129
129
#undef DEFINE_CASE
@@ -142,15 +142,15 @@ struct ArrayVariant {
142
142
TORCH_CHECK (false , " ArrayVariant is not implemented for ntensors = " , ntensors);
143
143
}
144
144
145
- c10 ::visit ([&](auto & a) {
145
+ std ::visit ([&](auto & a) {
146
146
for (auto i = 0 ; i < ntensors; ++i) {
147
147
a[i] = (char *)iter.data_ptr (i);
148
148
}
149
149
}, array);
150
150
}
151
151
152
152
void * data_ptr () {
153
- return c10 ::visit ([](auto & a){ return static_cast <void *>(&a); }, array);
153
+ return std ::visit ([](auto & a){ return static_cast <void *>(&a); }, array);
154
154
}
155
155
156
156
private:
@@ -159,7 +159,7 @@ struct ArrayVariant {
159
159
160
160
struct TrivialOffsetCalculatorVariant {
161
161
#define DEFINE_CASE (index ) TrivialOffsetCalculator<index>
162
- using TrivialOffsetCalculatorTypes = c10 ::variant<
162
+ using TrivialOffsetCalculatorTypes = std ::variant<
163
163
AT_FOR_8_CASES_WITH_COMMA (DEFINE_CASE)
164
164
>;
165
165
#undef DEFINE_CASE
@@ -178,7 +178,7 @@ struct TrivialOffsetCalculatorVariant {
178
178
}
179
179
180
180
void * data_ptr () {
181
- return c10 ::visit ([](auto & v){ return static_cast <void *>(&v); }, v);
181
+ return std ::visit ([](auto & v){ return static_cast <void *>(&v); }, v);
182
182
}
183
183
184
184
private:
@@ -187,7 +187,7 @@ struct TrivialOffsetCalculatorVariant {
187
187
188
188
struct LoadWithCastVariant {
189
189
#define DEFINE_CASE (index ) std::unique_ptr<memory::LoadWithCast<index>>
190
- using LoadWithCastPtr = c10 ::variant<
190
+ using LoadWithCastPtr = std ::variant<
191
191
AT_FOR_8_CASES_WITH_COMMA (DEFINE_CASE)
192
192
>;
193
193
#undef DEFINE_CASE
@@ -207,7 +207,7 @@ struct LoadWithCastVariant {
207
207
}
208
208
209
209
void * data_ptr () {
210
- return c10 ::visit ([](auto & v){ return static_cast <void *>(v.get ()); }, v);
210
+ return std ::visit ([](auto & v){ return static_cast <void *>(v.get ()); }, v);
211
211
}
212
212
213
213
private:
@@ -216,7 +216,7 @@ struct LoadWithCastVariant {
216
216
217
217
struct StoreWithCastVariant {
218
218
#define DEFINE_CASE (index ) std::unique_ptr<memory::StoreWithCast<index>>
219
- using StoreWithCastPtr = c10 ::variant<
219
+ using StoreWithCastPtr = std ::variant<
220
220
AT_FOR_8_CASES_WITH_COMMA (DEFINE_CASE)
221
221
>;
222
222
#undef DEFINE_CASE
@@ -236,7 +236,7 @@ struct StoreWithCastVariant {
236
236
}
237
237
238
238
void * data_ptr () {
239
- return c10 ::visit ([](auto & v){ return static_cast <void *>(v.get ()); }, v);
239
+ return std ::visit ([](auto & v){ return static_cast <void *>(v.get ()); }, v);
240
240
}
241
241
242
242
private:
0 commit comments