
Commit dee1009

cyyever authored and pytorchmergebot committed
[2/N] Move c10::variant to std::variant (pytorch#109723)
This PR moves most of the c10::variant calls to std::variant.

Pull Request resolved: pytorch#109723
Approved by: https://github.com/ezyang
1 parent c13177f commit dee1009


49 files changed: +337 −341 lines changed
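
Before the per-file hunks, a quick orientation: the substitutions are mechanical one-for-one replacements, since c10::variant exposed the same interface as std::variant. The sketch below is illustrative only and is not part of the commit (the type value_t and the helper functions are made up); it shows each std:: call that replaces its c10:: counterpart in this diff.

// Illustrative sketch, not commit code: the std:: replacements used throughout.
#include <cstddef>
#include <string>
#include <variant>

using value_t = std::variant<int, std::string>;  // was c10::variant<int, std::string>

bool is_string(const value_t& v) {
  // was c10::get_if<std::string>(&v) used as a boolean test
  return std::holds_alternative<std::string>(v);
}

const std::string& as_string(const value_t& v) {
  // was c10::get<std::string>(v); throws std::bad_variant_access on a mismatch
  return std::get<std::string>(v);
}

std::size_t active_size(const value_t& v) {
  // was c10::visit(...); the lambda runs on whichever alternative is active
  return std::visit([](const auto& x) { return sizeof(x); }, v);
}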

aten/src/ATen/cuda/jiterator_impl.h (+12 −12)

@@ -3,14 +3,14 @@
 
 #if AT_USE_JITERATOR()
 
-#include <c10/util/variant.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/native/cuda/jit_utils.h>
 #include <ATen/native/cuda/MemoryAccess.cuh>
 #include <ATen/native/cuda/JitLoops.cuh>
 
 #include <string>
+#include <variant>
 #include <vector>
 
 namespace at::native {
@@ -93,7 +93,7 @@ static std::unique_ptr<OffsetCalculator<N>> make_unique_offset_calculator(
 template <bool IS_INPUT>
 struct OffsetCalculatorVariant {
 #define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>
-  using OffsetCalculatorTypes = c10::variant<
+  using OffsetCalculatorTypes = std::variant<
     AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
   >;
 #undef DEFINE_CASE
@@ -113,7 +113,7 @@ struct OffsetCalculatorVariant {
   }
 
   void* data_ptr() {
-    return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
+    return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
   }
 
  private:
@@ -123,7 +123,7 @@ struct OffsetCalculatorVariant {
 struct ArrayVariant {
 // works for up to 8 input + 8 outputs
 #define DEFINE_CASE(index) at::detail::Array<char*, index>, at::detail::Array<char*, index+8>
-  using ArrayTypes = c10::variant<
+  using ArrayTypes = std::variant<
     AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
   >;
 #undef DEFINE_CASE
@@ -142,15 +142,15 @@ struct ArrayVariant {
       TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors);
     }
 
-    c10::visit([&](auto& a) {
+    std::visit([&](auto& a) {
       for (auto i = 0; i < ntensors; ++i) {
         a[i] = (char*)iter.data_ptr(i);
       }
     }, array);
   }
 
   void* data_ptr() {
-    return c10::visit([](auto & a){ return static_cast<void*>(&a); }, array);
+    return std::visit([](auto & a){ return static_cast<void*>(&a); }, array);
   }
 
  private:
@@ -159,7 +159,7 @@ struct ArrayVariant {
 
 struct TrivialOffsetCalculatorVariant {
 #define DEFINE_CASE(index) TrivialOffsetCalculator<index>
-  using TrivialOffsetCalculatorTypes = c10::variant<
+  using TrivialOffsetCalculatorTypes = std::variant<
     AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
   >;
 #undef DEFINE_CASE
@@ -178,7 +178,7 @@ struct TrivialOffsetCalculatorVariant {
   }
 
   void* data_ptr() {
-    return c10::visit([](auto & v){ return static_cast<void*>(&v); }, v);
+    return std::visit([](auto & v){ return static_cast<void*>(&v); }, v);
   }
 
  private:
@@ -187,7 +187,7 @@ struct TrivialOffsetCalculatorVariant {
 
 struct LoadWithCastVariant {
 #define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>
-  using LoadWithCastPtr = c10::variant<
+  using LoadWithCastPtr = std::variant<
     AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
   >;
 #undef DEFINE_CASE
@@ -207,7 +207,7 @@ struct LoadWithCastVariant {
   }
 
   void* data_ptr() {
-    return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
+    return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
   }
 
  private:
@@ -216,7 +216,7 @@ struct LoadWithCastVariant {
 
 struct StoreWithCastVariant {
 #define DEFINE_CASE(index) std::unique_ptr<memory::StoreWithCast<index>>
-  using StoreWithCastPtr = c10::variant<
+  using StoreWithCastPtr = std::variant<
     AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
   >;
 #undef DEFINE_CASE
@@ -236,7 +236,7 @@ struct StoreWithCastVariant {
   }
 
   void* data_ptr() {
-    return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
+    return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
   }
 
  private:

aten/src/ATen/functorch/DynamicLayer.h (−3)

@@ -9,9 +9,6 @@
 #include <c10/core/DispatchKey.h>
 #include <ATen/core/function_schema.h>
 #include <c10/util/Optional.h>
-#include <c10/util/variant.h>
-#include <unordered_map>
-#include <mutex>
 #include <c10/core/impl/LocalDispatchKeySet.h>
 #include <ATen/functorch/Interpreter.h>
 #include <ATen/functorch/VmapInterpreter.h>

c10/util/Exception.h (+2 −2)

@@ -3,11 +3,11 @@
 
 #include <c10/macros/Macros.h>
 #include <c10/util/StringUtil.h>
-#include <c10/util/variant.h>
 
 #include <cstddef>
 #include <exception>
 #include <string>
+#include <variant>
 #include <vector>
 
 #if defined(_MSC_VER) && _MSC_VER <= 1900
@@ -115,7 +115,7 @@ class C10_API Warning {
   class C10_API UserWarning {};
   class C10_API DeprecationWarning {};
 
-  using warning_variant_t = c10::variant<UserWarning, DeprecationWarning>;
+  using warning_variant_t = std::variant<UserWarning, DeprecationWarning>;
 
   Warning(
       warning_variant_t type,

c10/util/MaybeOwned.h (+1)

@@ -4,6 +4,7 @@
 #include <c10/util/Exception.h>
 #include <c10/util/in_place.h>
 
+#include <memory>
 #include <type_traits>
 
 namespace c10 {

test/cpp/api/enum.cpp (+2 −1)

@@ -1,6 +1,7 @@
 #include <gtest/gtest.h>
 
 #include <torch/torch.h>
+#include <variant>
 
 #include <test/cpp/api/support.h>
 
@@ -13,7 +14,7 @@
 }
 
 TEST(EnumTest, AllEnums) {
-  c10::variant<
+  std::variant<
       torch::enumtype::kLinear,
       torch::enumtype::kConv1D,
      torch::enumtype::kConv2D,

test/cpp/tensorexpr/test_external_calls.cpp (+3 −3)

@@ -951,11 +951,11 @@ TEST(ExternalCall, JitCustomFusionOp) {
         torch::jit::tensorexpr::BufHandle result_buf(
             "nnc_add_mul_res_buf", output_shape, output_dtype);
         const torch::jit::tensorexpr::BufHandle& a =
-            c10::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
+            std::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
         const torch::jit::tensorexpr::BufHandle& b =
-            c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
+            std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
         const torch::jit::tensorexpr::BufHandle& c =
-            c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
+            std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
         torch::jit::tensorexpr::StmtPtr s =
             torch::jit::tensorexpr::ExternalCall::make(
                 result_buf, external_func_name, {a, b, c}, {});

test/cpp/tensorexpr/test_kernel.cpp (+1 −1)

@@ -1667,7 +1667,7 @@ Tensor lowerNanToNum(
     const std::vector<ExprHandle>& outputStrides,
     const c10::optional<ScalarType>& outputType,
     at::Device device) {
-  auto input_buf = c10::get<BufHandle>(inputs[0]);
+  auto input_buf = std::get<BufHandle>(inputs[0]);
   auto e = Compute(
       "custom_nan_to_num",
       outputShape,

torch/csrc/Exceptions.cpp (+1 −1)

@@ -286,7 +286,7 @@ PyObject* map_warning_to_python_type(const c10::Warning& warning) {
       return PyExc_DeprecationWarning;
     }
   };
-  return c10::visit(Visitor(), warning.type());
+  return std::visit(Visitor(), warning.type());
 }
 
 /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
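
For context on the hunk above: Visitor is a small struct whose call operators cover each alternative of warning_variant_t (declared in the c10/util/Exception.h hunk earlier), and std::visit dispatches to the overload matching the active alternative. A standalone sketch, with the types redeclared here purely for illustration rather than taken from the c10 headers:

// Minimal sketch of visitor dispatch over a two-alternative variant.
#include <variant>

struct UserWarning {};
struct DeprecationWarning {};
using warning_variant_t = std::variant<UserWarning, DeprecationWarning>;

struct Visitor {
  const char* operator()(const UserWarning&) const { return "UserWarning"; }
  const char* operator()(const DeprecationWarning&) const { return "DeprecationWarning"; }
};

const char* warning_name(const warning_variant_t& w) {
  return std::visit(Visitor(), w);  // picks the operator() matching the active type
}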

torch/csrc/api/include/torch/enum.h (+6 −6)

@@ -1,10 +1,10 @@
 #pragma once
 
 #include <string>
+#include <variant>
 
 #include <ATen/core/Reduction.h>
 #include <c10/util/Exception.h>
-#include <c10/util/variant.h>
 #include <torch/csrc/Export.h>
 
 #define TORCH_ENUM_DECLARE(name) \
@@ -42,7 +42,7 @@
 //
 // ```
 // struct TORCH_API SomeOptions {
-//   typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
+//   typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
 //       reduction_t; SomeOptions(reduction_t reduction = torch::kMean) :
 //       reduction_(reduction) {}
 //
@@ -188,16 +188,16 @@ struct _compute_enum_name {
 
 template <typename V>
 std::string get_enum_name(V variant_enum) {
-  return c10::visit(enumtype::_compute_enum_name{}, variant_enum);
+  return std::visit(enumtype::_compute_enum_name{}, variant_enum);
 }
 
 template <typename V>
 at::Reduction::Reduction reduction_get_enum(V variant_enum) {
-  if (c10::get_if<enumtype::kNone>(&variant_enum)) {
+  if (std::holds_alternative<enumtype::kNone>(variant_enum)) {
     return at::Reduction::None;
-  } else if (c10::get_if<enumtype::kMean>(&variant_enum)) {
+  } else if (std::holds_alternative<enumtype::kMean>(variant_enum)) {
     return at::Reduction::Mean;
-  } else if (c10::get_if<enumtype::kSum>(&variant_enum)) {
+  } else if (std::holds_alternative<enumtype::kSum>(variant_enum)) {
     return at::Reduction::Sum;
   } else {
     TORCH_CHECK(
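
A note on the reduction_get_enum hunk above: where c10::get_if<T>(&v) served only as a truth test, the commit substitutes std::holds_alternative<T>(v), which takes the variant by reference and returns bool; where the pointer-returning form is kept, it becomes std::get_if, as in the embedding.h TORCH_CHECK hunk further down. A minimal standalone sketch of the two idioms, using illustrative types rather than the torch enum types:

// Boolean test vs. pointer access on a std::variant.
#include <variant>

using mode_variant = std::variant<int, double>;

bool is_int(const mode_variant& m) {
  return std::holds_alternative<int>(m);       // pure membership test -> bool
}

int int_or_zero(const mode_variant& m) {
  if (const int* p = std::get_if<int>(&m)) {   // pointer access when the value is needed
    return *p;
  }
  return 0;
}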

torch/csrc/api/include/torch/nn/functional/conv.h (+3 −3)

@@ -31,7 +31,7 @@ inline Tensor conv1d(
     const Conv1dFuncOptions::padding_t& padding,
     ExpandingArray<1> dilation,
     int64_t groups) {
-  return c10::visit(
+  return std::visit(
       [&](const auto& pad) {
         return torch::conv1d(
             input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
@@ -77,7 +77,7 @@ inline Tensor conv2d(
     const Conv2dFuncOptions::padding_t& padding,
     ExpandingArray<2> dilation,
     int64_t groups) {
-  return c10::visit(
+  return std::visit(
       [&](const auto& pad) {
         return torch::conv2d(
             input, weight, bias, stride, padding_unwrap(pad), dilation, groups);
@@ -123,7 +123,7 @@ inline Tensor conv3d(
     const Conv3dFuncOptions::padding_t& padding,
     ExpandingArray<3> dilation,
     int64_t groups) {
-  return c10::visit(
+  return std::visit(
       [&](const auto& pad) {
         return torch::conv3d(
             input, weight, bias, stride, padding_unwrap(pad), dilation, groups);

torch/csrc/api/include/torch/nn/functional/embedding.h (+4 −4)

@@ -135,11 +135,11 @@ inline Tensor embedding_bag(
 
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   int mode_enum;
-  if (c10::get_if<enumtype::kSum>(&mode)) {
+  if (std::holds_alternative<enumtype::kSum>(mode)) {
     mode_enum = 0;
-  } else if (c10::get_if<enumtype::kMean>(&mode)) {
+  } else if (std::holds_alternative<enumtype::kMean>(mode)) {
     mode_enum = 1;
-  } else if (c10::get_if<enumtype::kMax>(&mode)) {
+  } else if (std::holds_alternative<enumtype::kMax>(mode)) {
     mode_enum = 2;
     TORCH_CHECK(
         !scale_grad_by_freq,
@@ -155,7 +155,7 @@ inline Tensor embedding_bag(
   }
 
   TORCH_CHECK(
-      !per_sample_weights_.defined() || c10::get_if<enumtype::kSum>(&mode),
+      !per_sample_weights_.defined() || std::get_if<enumtype::kSum>(&mode),
       "embedding_bag: per_sample_weights was not null. ",
       "per_sample_weights is only supported for mode='kSum' (got mode='",
       torch::enumtype::get_enum_name(mode),

torch/csrc/api/include/torch/nn/functional/loss.h (+10 −9)

@@ -50,23 +50,24 @@ inline Tensor kl_div(
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   torch::Reduction::Reduction reduction_enum;
 
-  if (c10::get_if<enumtype::kMean>(&reduction)) {
+  if (std::holds_alternative<enumtype::kMean>(reduction)) {
     TORCH_WARN(
         "reduction: 'mean' divides the total loss by both the batch size and the support size."
         "'batchmean' divides only by the batch size, and aligns with the KL div math definition."
         "'mean' will be changed to behave the same as 'batchmean' in the next major release.");
   }
 
   // special case for batchmean
-  if (c10::get_if<enumtype::kBatchMean>(&reduction)) {
+  if (std::holds_alternative<enumtype::kBatchMean>(reduction)) {
     reduction_enum = torch::Reduction::Sum;
   } else {
     reduction_enum = enumtype::reduction_get_enum(reduction);
   }
 
   auto reduced = torch::kl_div(input, target, reduction_enum, log_target);
 
-  if (c10::get_if<enumtype::kBatchMean>(&reduction) && input.dim() != 0) {
+  if (std::holds_alternative<enumtype::kBatchMean>(reduction) &&
+      input.dim() != 0) {
     reduced = reduced / input.sizes()[0];
   }
 
@@ -531,11 +532,11 @@ inline Tensor multilabel_soft_margin_loss(
 
   Tensor ret;
 
-  if (c10::get_if<enumtype::kNone>(&reduction)) {
+  if (std::holds_alternative<enumtype::kNone>(reduction)) {
     ret = loss;
-  } else if (c10::get_if<enumtype::kMean>(&reduction)) {
+  } else if (std::holds_alternative<enumtype::kMean>(reduction)) {
     ret = loss.mean();
-  } else if (c10::get_if<enumtype::kSum>(&reduction)) {
+  } else if (std::holds_alternative<enumtype::kSum>(reduction)) {
     ret = loss.sum();
   } else {
     ret = input;
@@ -661,11 +662,11 @@ inline Tensor triplet_margin_with_distance_loss(
   auto loss = torch::clamp_min(dist_pos - dist_neg + margin, 0);
 
   Tensor ret;
-  if (c10::get_if<enumtype::kNone>(&reduction)) {
+  if (std::holds_alternative<enumtype::kNone>(reduction)) {
     ret = loss;
-  } else if (c10::get_if<enumtype::kMean>(&reduction)) {
+  } else if (std::holds_alternative<enumtype::kMean>(reduction)) {
     ret = loss.mean();
-  } else if (c10::get_if<enumtype::kSum>(&reduction)) {
+  } else if (std::holds_alternative<enumtype::kSum>(reduction)) {
     ret = loss.sum();
   } else {
     ret = anchor;

torch/csrc/api/include/torch/nn/functional/padding.h (+4 −4)

@@ -15,13 +15,13 @@ inline Tensor pad(
     PadFuncOptions::mode_t mode,
     double value) {
   const auto mode_enum = [&] {
-    if (c10::get_if<enumtype::kConstant>(&mode)) {
+    if (std::holds_alternative<enumtype::kConstant>(mode)) {
      return at::padding_mode::constant;
-    } else if (c10::get_if<enumtype::kReflect>(&mode)) {
+    } else if (std::holds_alternative<enumtype::kReflect>(mode)) {
      return at::padding_mode::reflect;
-    } else if (c10::get_if<enumtype::kReplicate>(&mode)) {
+    } else if (std::holds_alternative<enumtype::kReplicate>(mode)) {
      return at::padding_mode::replicate;
-    } else if (c10::get_if<enumtype::kCircular>(&mode)) {
+    } else if (std::holds_alternative<enumtype::kCircular>(mode)) {
      return at::padding_mode::circular;
    }
    TORCH_CHECK(false, "Unrecognised padding mode");
