
Commit 6ecd658

cyyever authored and pytorchmergebot committed
Remove unnecessary const_casts (pytorch#121225)
Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#121225
Approved by: https://github.com/soulitzer
1 parent 85c807b commit 6ecd658

File tree: 7 files changed, +13 -30 lines

aten/src/ATen/DLConvertor.cpp (+2 -6)

@@ -283,12 +283,8 @@ DLManagedTensor* toDLPack(const Tensor& src) {
   atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
   atDLMTensor->tensor.dl_tensor.ndim = src.dim();
   atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
-  atDLMTensor->tensor.dl_tensor.shape =
-      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-      const_cast<int64_t*>(view.sizes().data());
-  atDLMTensor->tensor.dl_tensor.strides =
-      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-      const_cast<int64_t*>(view.strides().data());
+  atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
+  atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
   atDLMTensor->tensor.dl_tensor.byte_offset = 0;
   return &(atDLMTensor->tensor);
 }

aten/src/ATen/cuda/detail/CUDAHooks.cpp (+1 -1)

@@ -137,7 +137,7 @@ bool CUDAHooks::isPinnedPtr(const void* data) const {
   cudaPointerAttributes attr;
   // We do not believe that CUDA needs mutable access to the data
   // here.
-  cudaError_t err = cudaPointerGetAttributes(&attr, const_cast<void*>(data));
+  cudaError_t err = cudaPointerGetAttributes(&attr, data);
 #if !defined(USE_ROCM)
   if (err == cudaErrorInvalidValue) {
     (void)cudaGetLastError(); // clear CUDA error
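
For context, a minimal stand-alone sketch of the same pattern (not the ATen code; it assumes a reasonably recent CUDA toolkit, where cudaPointerGetAttributes takes a const void* and cudaPointerAttributes exposes a type field):

#include <cuda_runtime.h>
#include <cstdio>

// Returns true if `data` points into pinned (page-locked) host memory.
// The const pointer is passed through unchanged: no const_cast is needed.
static bool is_pinned(const void* data) {
  cudaPointerAttributes attr{};
  cudaError_t err = cudaPointerGetAttributes(&attr, data);
  if (err == cudaErrorInvalidValue) {
    (void)cudaGetLastError(); // clear the error left by an unregistered host pointer
    return false;
  }
  return err == cudaSuccess && attr.type == cudaMemoryTypeHost;
}

int main() {
  void* p = nullptr;
  cudaMallocHost(&p, 16); // pinned host allocation
  std::printf("pinned: %d\n", is_pinned(p) ? 1 : 0);
  cudaFreeHost(p);
  return 0;
}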

aten/src/ATen/dlpack.h (+2 -2)

@@ -195,12 +195,12 @@ typedef struct {
   /*! \brief The data type of the pointer*/
   DLDataType dtype;
   /*! \brief The shape of the tensor */
-  int64_t* shape;
+  const int64_t* shape;
   /*!
    * \brief strides of the tensor (in number of elements, not bytes)
    * can be NULL, indicating tensor is compact and row-majored.
    */
-  int64_t* strides;
+  const int64_t* strides;
   /*! \brief The offset in bytes to the beginning pointer to data */
   uint64_t byte_offset;
 } DLTensor;
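
Making shape and strides const here is what lets toDLPack (first hunk above) assign view.sizes().data() and view.strides().data() directly. A minimal sketch of the idea with hypothetical stand-in types (MiniDLTensor, MiniView), not the actual ATen/DLPack definitions:

#include <cstdint>
#include <vector>

// Simplified stand-in for the relevant DLTensor fields after this change.
struct MiniDLTensor {
  const int64_t* shape;
  const int64_t* strides;
  int32_t ndim;
};

// Stand-in for a tensor view whose metadata is only exposed as const pointers.
struct MiniView {
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;
  const int64_t* sizes() const { return sizes_.data(); }
  const int64_t* strides() const { return strides_.data(); }
};

int main() {
  MiniView view{{2, 3}, {3, 1}};
  MiniDLTensor dl{};
  dl.ndim = 2;
  dl.shape = view.sizes();     // const* to const* field: no cast needed
  dl.strides = view.strides(); // previously required const_cast<int64_t*>(...)
  return 0;
}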

torch/csrc/autograd/VariableTypeUtils.h (+1 -1)

@@ -104,7 +104,7 @@ inline void throw_error_for_complex_autograd(
 
 // TODO: Blegh, bare references
 
-inline void rebase_history(Variable& var, std::shared_ptr<Node> grad_fn) {
+inline void rebase_history(const Variable& var, std::shared_ptr<Node> grad_fn) {
   if (grad_fn && var.defined()) {
     grad_fn->add_input_metadata(var);
     impl::rebase_history(var, {std::move(grad_fn), 0});
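
rebase_history can take a const reference because Variable (an alias of at::Tensor) is a handle over a shared TensorImpl, and constness of the handle does not make the underlying autograd metadata immutable. A minimal, torch-free sketch of that idea, with hypothetical names (Handle, Impl, rebase_history_like):

#include <memory>
#include <string>

struct Impl {
  std::string grad_fn_name; // mutable autograd-style metadata
};

class Handle { // pointer-like, the way at::Tensor wraps a TensorImpl
 public:
  Handle() : impl_(std::make_shared<Impl>()) {}
  // A const handle still hands out a mutable pointer to the shared payload.
  Impl* unsafe_impl() const { return impl_.get(); }
 private:
  std::shared_ptr<Impl> impl_;
};

// Takes a const reference, yet can still rewire the underlying metadata.
inline void rebase_history_like(const Handle& var, std::string grad_fn_name) {
  var.unsafe_impl()->grad_fn_name = std::move(grad_fn_name);
}

int main() {
  Handle v;
  rebase_history_like(v, "MulBackward0");
  return v.unsafe_impl()->grad_fn_name == "MulBackward0" ? 0 : 1;
}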

torch/csrc/autograd/autograd_not_implemented_fallback.cpp (+4 -8)

@@ -184,8 +184,7 @@ static void basicAutogradNotImplementedFallbackImpl(
         // users typically call .backward() and backprop through
         // the entire program).
         if (t.is_view() && is_mutable_output) {
-          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-          auto& base = const_cast<at::TensorBase&>(t._base());
+          const auto& base = t._base();
           if (base.requires_grad()) {
             // Can only register_hook on tensors that require grad.
             base.register_hook([op_name](const at::TensorBase& grad) {
@@ -210,8 +209,7 @@
         // rebase_history assumes single Tensor(a!) return, and in general
         // custom ops don't have a good in-place story.
         if (!is_mutable_output) {
-          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-          set_history(const_cast<at::Tensor&>(t), grad_fn);
+          set_history(t, grad_fn);
         }
       },
       stack,
@@ -418,11 +416,9 @@ static void autogradNotImplementedFallbackImpl(
       [&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) {
        if (isDifferentiableType(t.scalar_type())) {
          if (is_inplace_output[idx_ret]) {
-            // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-            rebase_history(const_cast<at::Tensor&>(t), grad_fn);
+            rebase_history(t, grad_fn);
          } else {
-            // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-            set_history(const_cast<at::Tensor&>(t), grad_fn);
+            set_history(t, grad_fn);
          }
        }
      },

torch/csrc/autograd/functions/utils.h (+2 -10)

@@ -65,7 +65,7 @@ inline bool compute_requires_grad(Args&&... args) {
 }
 
 inline void set_history(
-    at::Tensor& variable,
+    const at::Tensor& variable,
     const std::shared_ptr<Node>& grad_fn) {
   TORCH_CHECK(grad_fn != nullptr);
   if (variable.defined()) {
@@ -81,15 +81,7 @@ inline void set_history(
 }
 
 inline void set_history(
-    std::vector<Variable>&& variables,
-    const std::shared_ptr<Node>& grad_fn) {
-  for (auto& variable : variables) {
-    set_history(variable, grad_fn);
-  }
-}
-
-inline void set_history(
-    std::vector<Variable>& variables,
+    const std::vector<Variable>& variables,
     const std::shared_ptr<Node>& grad_fn) {
   for (auto& variable : variables) {
     set_history(variable, grad_fn);
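
With the remaining overload taking const std::vector<Variable>&, the separate rvalue overload becomes redundant: a const lvalue reference binds to lvalues and temporaries alike. A small stand-alone illustration (hypothetical set_history_like, not the torch function):

#include <iostream>
#include <string>
#include <vector>

// One const& overload serves both call styles below.
static void set_history_like(const std::vector<std::string>& variables,
                             const std::string& grad_fn_name) {
  for (const auto& variable : variables) {
    std::cout << "attach " << grad_fn_name << " to " << variable << '\n';
  }
}

int main() {
  std::vector<std::string> vars{"a", "b"};
  set_history_like(vars, "AddBackward");                           // lvalue
  set_history_like(std::vector<std::string>{"c"}, "MulBackward");  // temporary (rvalue)
  return 0;
}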

torch/csrc/distributed/rpc/tensorpipe_utils.cpp (+1 -2)

@@ -152,8 +152,7 @@ std::tuple<tensorpipe::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
   buffers.payload = std::move(rpcMessage->payload());
   // TensorPipe uses the same Message class for both reading and writing, thus
   // it uses non-const pointers even though it doesn't modify them when writing.
-  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-  char* payloadPtr = const_cast<char*>(buffers.payload.data());
+  char* payloadPtr = buffers.payload.data();
   // kTpMessagePayloadIdx = 2
   tpMessage.payloads.push_back(
       tensorpipe::Message::Payload{payloadPtr, buffers.payload.size()});
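
Since buffers.payload is a non-const std::vector<char>, its data() member already yields a mutable char*, so no cast is needed before handing it to an API that wants a non-const pointer. A tiny illustration with a hypothetical PayloadLike stand-in for tensorpipe::Message::Payload:

#include <cstddef>
#include <cstring>
#include <vector>

struct PayloadLike { // hypothetical stand-in, not the real tensorpipe type
  char* data;
  std::size_t length;
};

int main() {
  std::vector<char> payload(8, '\0');
  std::memcpy(payload.data(), "rpc", 3);
  // data() on a non-const vector returns char* directly; no const_cast needed.
  PayloadLike p{payload.data(), payload.size()};
  return p.length == 8 ? 0 : 1;
}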
