
Commit 65b4080

Revert "Relandx3 "SymIntify cat and narrow" (pytorch#86289)"
This reverts commit a00f848. Reverted pytorch#86289 on behalf of https://github.com/malfet because @seemether unlanded the rest of the stack and it would fail internal import anyway.
1 parent 5b69b87 commit 65b4080

18 files changed: +60 −127 lines

aten/src/ATen/BatchingRegistrations.cpp

+2 −2

@@ -1096,9 +1096,9 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("expand_as", native::expand_as); // composite wrt autograd
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
-  // There is another variant of narrow. However, we don't
+  // NB: static_cast because there's another variant of narrow. However, we don't
   // want to support the other variant yet bc it isn't documented...
-  m.impl("narrow", native::narrow_symint); // composite wrt autograd
+  m.impl("narrow", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t,int64_t)>(native::narrow)); // composite wrt autograd
   m.impl("numpy_T", native::numpy_T); // composite wrt autograd
   m.impl("matrix_H", native::matrix_H); // composite wrt autograd
   m.impl("mT", native::mT); // composite wrt autograd

aten/src/ATen/WrapDimUtils.h

+19 −27

@@ -8,17 +8,22 @@
 
 namespace at {
 
-// if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the
-// range [-1, 0]. This is a special case for scalar tensors and manifests in
-// e.g. torch.sum(scalar_tensor, 0) Otherwise, dim should be in the range
-// [-dim_post_expr, dim_post_expr-1].
-using c10::maybe_wrap_dim;
+static inline int64_t maybe_wrap_dim(
+    int64_t dim,
+    int64_t dim_post_expr,
+    bool wrap_scalar = true) {
+  // if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the
+  // range [-1, 0]. This is a special case for scalar tensors and manifests in
+  // e.g. torch.sum(scalar_tensor, 0) Otherwise, dim should be in the range
+  // [-dim_post_expr, dim_post_expr-1].
+  return c10::maybe_wrap_dim(dim, dim_post_expr, wrap_scalar);
+}
 
-inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
+static inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
   return maybe_wrap_dim(dim, tensor->dim());
 }
 
-inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
+static inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
   if (tensors.size() == 0) {
     // can't wrap empty TensorList; rely on underlying implementation to throw
     // error if necessary.
@@ -27,7 +32,7 @@ inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
   return maybe_wrap_dim(dim, tensors[0].dim());
 }
 
-inline int64_t maybe_wrap_dim(
+static inline int64_t maybe_wrap_dim(
     int64_t dim,
     const std::vector<std::vector<int64_t>>& tensor_sizes) {
   if (tensor_sizes.size() == 0) {
@@ -40,7 +45,7 @@ inline int64_t maybe_wrap_dim(
 
 // wrap each dim in the dims array, taking dim_post_expr as the true number of
 // dimensions
-inline void maybe_wrap_dims_n(
+static inline void maybe_wrap_dims_n(
     int64_t* dims,
     int64_t ndims,
     int64_t dim_post_expr) {
@@ -80,32 +85,19 @@ inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) {
 // dimension behavior and dimension size checking). We maintain this behavior
 // for backwards compatibility, but only for this specific size (i.e. other
 // empty sizes are not skipped).
-template <typename T>
-inline int64_t _legacy_cat_wrap_dim(
+static inline int64_t legacy_cat_wrap_dim(
     int64_t dim,
-    const std::vector<std::vector<T>>& tensor_sizes) {
+    const std::vector<std::vector<int64_t>>& tensor_sizes) {
   for (auto& sizes : tensor_sizes) {
-    if (sizes == std::vector<T>({0})) {
+    if (sizes == std::vector<int64_t>({0})) {
       continue;
     }
     return maybe_wrap_dim(dim, sizes.size());
   }
   return dim;
 }
 
-inline int64_t legacy_cat_wrap_dim(
-    int64_t dim,
-    const std::vector<std::vector<int64_t>>& tensor_sizes) {
-  return _legacy_cat_wrap_dim<int64_t>(dim, tensor_sizes);
-}
-
-inline int64_t legacy_cat_wrap_dim_symint(
-    int64_t dim,
-    const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
-  return _legacy_cat_wrap_dim<c10::SymInt>(dim, tensor_sizes);
-}
-
-inline int64_t legacy_cat_wrap_dim(
+static inline int64_t legacy_cat_wrap_dim(
     int64_t dim,
     const MaterializedITensorListRef& tensors) {
   for (const Tensor& tensor : tensors) {
@@ -118,7 +110,7 @@ inline int64_t legacy_cat_wrap_dim(
 }
 
 // wrap negative dims in a vector
-inline void wrap_all_dims(
+static inline void wrap_all_dims(
     std::vector<int64_t>& dims_to_wrap,
     int64_t tensor_total_dims) {
   for (const auto i : c10::irange(dims_to_wrap.size())) {
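
The comment carried into the restored wrapper above describes the wrapping rule itself: a negative dim indexes from the end, and a 0-dim (scalar) tensor accepts dims in [-1, 0] when wrap_scalar is true. A minimal sketch of that assumed behavior (not the ATen implementation):

// Sketch of the dim-wrapping semantics described in the header comment.
#include <cassert>
#include <cstdint>
#include <stdexcept>

int64_t wrap_dim_sketch(int64_t dim, int64_t ndim, bool wrap_scalar = true) {
  if (ndim == 0 && wrap_scalar) ndim = 1;  // scalar tensors behave like 1-d here
  if (dim < -ndim || dim >= ndim) throw std::out_of_range("dim out of range");
  return dim < 0 ? dim + ndim : dim;
}

int main() {
  assert(wrap_dim_sketch(-1, 4) == 3);  // last dim of a 4-d tensor
  assert(wrap_dim_sketch(2, 4) == 2);   // non-negative dims pass through
  assert(wrap_dim_sketch(0, 0) == 0);   // scalar special case: dim 0 allowed
}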

aten/src/ATen/functorch/BatchRulesDecompositions.cpp

+1 −1

@@ -168,7 +168,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   OP_DECOMPOSE2(movedim, int);
   OP_DECOMPOSE(msort);
   OP_DECOMPOSE(mT);
-  m.impl("narrow", native::narrow_symint);
+  OP_DECOMPOSE(narrow);
   OP_DECOMPOSE(negative);
   OP_DECOMPOSE2(frobenius_norm, dim);
   OP_DECOMPOSE2(nuclear_norm, dim);

aten/src/ATen/native/NonSymbolicBC.h

−1

@@ -9,5 +9,4 @@ namespace native {
 // However, in certain cases (such as static runtime), we call the native versions of the ops directly.
 // In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
 TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
-TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
 }}

aten/src/ATen/native/TensorShape.cpp

+2 −16

@@ -11,7 +11,6 @@
 #include <ATen/core/DimVector.h>
 #include <ATen/core/IListRef.h>
 #include <ATen/native/Copy.h>
-#include <ATen/native/NonSymbolicBC.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/TensorShape.h>
@@ -1137,24 +1136,11 @@ Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
   return at::slice(self, dim, start, start + length, 1);
 }
 
-Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
-  TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
-  auto cur_size = self.sym_size(dim);
-  if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
-    start = maybe_wrap_dim(start, cur_size);
-  }
-  TORCH_CHECK(length >= 0 && start <= cur_size - length,
-      "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
-  return at::slice_symint(self, dim, start, start + length, 1);
-}
-
-// This overload exists purely for XLA, because they wanted to pass in "symbolic"
-// start via Tensor.
-Tensor narrow_tensor_symint(const Tensor& self, int64_t dim, const Tensor& start, SymInt length) {
+Tensor narrow(const Tensor& self, int64_t dim, const Tensor& start, int64_t length) {
   TORCH_CHECK(start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false),
       "start must be an 0-dim integral Tensor.");
   int64_t st = start.item<int64_t>();
-  return at::narrow_symint(self, dim, c10::SymInt(st), length);
+  return at::narrow(self, dim, st, length);
 }
 
 std::tuple<DimVector, DimVector, std::vector<int64_t>>
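
The surviving narrow overloads reduce to slice with step 1 over [start, start + length). A small illustrative check of that equivalence using the public libtorch C++ API (this snippet is not part of the diff and assumes a libtorch build to link against):

// Illustrative only: narrow(dim, start, length) == slice(dim, start, start + length, 1).
#include <torch/torch.h>

int main() {
  auto t = torch::arange(10);
  auto a = t.narrow(0, /*start=*/2, /*length=*/4);        // elements [2, 6)
  auto b = t.slice(0, /*start=*/2, /*end=*/6, /*step=*/1); // same view
  TORCH_CHECK(torch::equal(a, b));
}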

aten/src/ATen/native/native_functions.yaml

+2 −6

@@ -3710,19 +3710,15 @@
   dispatch:
     CPU: narrow_copy_dense_cpu_out
 
-- func: narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
+- func: narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)
   variants: function, method
   device_check: NoCheck
   device_guard: False
-  dispatch:
-    CompositeImplicitAutograd: narrow_symint
 
-- func: narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)
+- func: narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)
   variants: function, method
   device_check: NoCheck
   device_guard: False
-  dispatch:
-    CompositeImplicitAutograd: narrow_tensor_symint
 
 - func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
   dispatch:

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

−1

@@ -17,7 +17,6 @@
 #include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
 #endif
 
-#include <ATen/native/NonSymbolicBC.h>
 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
 #include <ATen/native/nested/NestedTensorMath.h>
 #include <ATen/native/nested/NestedTensorUtils.h>

aten/src/ATen/native/transformers/cuda/attention.cu

−1

@@ -16,7 +16,6 @@
 
 #include <c10/cuda/CUDAMathCompat.h>
 
-#include <ATen/native/NonSymbolicBC.h>
 #include <ATen/native/nested/NestedTensorUtils.h>
 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
 
c10/core/WrapDimMinimal.cpp

+6 −10

@@ -3,8 +3,10 @@
 namespace c10 {
 namespace detail {
 
-template <typename T>
-T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) {
+int64_t maybe_wrap_dim_slow(
+    int64_t dim,
+    int64_t dim_post_expr,
+    bool wrap_scalar) {
   TORCH_CHECK_INDEX(
       dim_post_expr >= 0, "Rank cannot be negative but got ", dim_post_expr);
 
@@ -17,8 +19,8 @@ T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) {
     return c10::maybe_wrap_dim(dim, /*dim_post_expr=*/1, /*wrap_scalar=*/false);
   }
 
-  T min = dim_post_expr * -1;
-  T max = dim_post_expr - 1;
+  int64_t min = -dim_post_expr;
+  int64_t max = dim_post_expr - 1;
   TORCH_CHECK_INDEX(
       min <= dim && dim <= max,
       "Dimension out of range (expected to be in range of [",
@@ -33,11 +35,5 @@ T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) {
       false, "should never reach here as dim should be out-of-bounds");
 }
 
-// Explicitly instantiate the template at the two types it will be used
-template C10_API int64_t
-maybe_wrap_dim_slow(int64_t dim, int64_t dim_post_expr, bool wrap_scalar);
-template C10_API SymInt
-maybe_wrap_dim_slow(SymInt dim, SymInt dim_post_expr, bool wrap_scalar);
-
 } // namespace detail
 } // namespace c10
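
The lines removed above were the explicit instantiations that made the templated maybe_wrap_dim_slow usable from other translation units; the revert goes back to a plain int64_t definition. A generic sketch of that header-declaration-plus-explicit-instantiation pattern, with illustrative names rather than the c10 ones:

// Sketch only: declare a template in a header, define it in one .cpp, and
// explicitly instantiate it so the linker can find the needed symbols.
#include <cstdint>

// header (sketch): template <typename T> T wrap_slow(T dim, T rank);

template <typename T>
T wrap_slow(T dim, T rank) {
  return dim < 0 ? dim + rank : dim;
}

// Explicit instantiation: emits the int64_t definition in this translation unit.
template int64_t wrap_slow<int64_t>(int64_t dim, int64_t rank);

int main() {
  return wrap_slow<int64_t>(-1, 4) == 3 ? 0 : 1;  // -1 wraps to 3 for rank 4
}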

c10/core/WrapDimMinimal.h

+10 −29

@@ -1,44 +1,25 @@
 #pragma once
 
-#include <c10/core/SymInt.h>
 #include <c10/util/Exception.h>
 
 namespace c10 {
 
 namespace detail {
-// This template can only be specialized at int64_t and c10::SymInt;
-// you'll get linker errors otherwise
-template <typename T>
-C10_API T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar);
-} // namespace detail
-
-template <typename T>
-T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) {
-  // Inline the fast paths
-  if (C10_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) {
-    // For SymInts, we want an explicit control flow to trigger a guard, so we
-    // may as well branch too.
-    if (dim < 0) {
-      return dim + dim_post_expr;
-    }
-    return dim;
-  }
-  // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors)
-  return c10::detail::maybe_wrap_dim_slow<T>(dim, dim_post_expr, wrap_scalar);
+C10_API int64_t
+maybe_wrap_dim_slow(int64_t dim, int64_t dim_post_expr, bool wrap_scalar);
 }
 
-inline int64_t maybe_wrap_dim(
+static inline int64_t maybe_wrap_dim(
     int64_t dim,
     int64_t dim_post_expr,
     bool wrap_scalar = true) {
-  return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar);
-}
-
-inline c10::SymInt maybe_wrap_dim(
-    c10::SymInt dim,
-    c10::SymInt dim_post_expr,
-    bool wrap_scalar = true) {
-  return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar);
+  // Inline the fast paths
+  if (C10_LIKELY(-dim_post_expr <= dim && dim < dim_post_expr)) {
+    // Branch-less version of dim + (dim < 0 ? dim_post_expr : 0)
+    return dim + dim_post_expr * (dim < 0);
+  }
+  // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors)
+  return c10::detail::maybe_wrap_dim_slow(dim, dim_post_expr, wrap_scalar);
 }
 
 } // namespace c10
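
The reinstated fast path avoids a branch by multiplying by the boolean (dim < 0), which converts to 0 or 1. A tiny self-contained check that the branch-less form matches the branchy one for every in-range dim (a sketch, not the c10 header):

// Sketch only: verify dim + dim_post_expr * (dim < 0) == dim + (dim < 0 ? dim_post_expr : 0).
#include <cassert>
#include <cstdint>

int main() {
  const int64_t dim_post_expr = 4;
  for (int64_t dim = -dim_post_expr; dim < dim_post_expr; ++dim) {
    int64_t branchless = dim + dim_post_expr * (dim < 0);
    int64_t branchy = dim + (dim < 0 ? dim_post_expr : 0);
    assert(branchless == branchy);
  }
}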

functorch/test/test_aotdispatch.py

+6

@@ -738,13 +738,15 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
 xfail('block_diag', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('broadcast_tensors', ''),  # 'int' and 'torch._C.SymIntNode'
 xfail('cartesian_prod', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
+xfail('cat', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('cdist', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('cholesky_inverse', ''),  # could not find kernel
 xfail('cholesky_solve', ''),  # could not find kernel
 xfail('chunk', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('column_stack', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('combinations', ''),  # aten.masked_select.default
 xfail('complex', ''),  # aten.view_as_real.default - couldn't find symbolic meta function/decomposition
+xfail('constant_pad_nd', ''),  # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
 xfail('cross', ''),  # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
 xfail('cummax', ''),  # aten.cummax.default - couldn't find symbolic meta function/decomposition
 xfail('cummin', ''),  # aten.cummin.default - couldn't find symbolic meta function/decomposition
@@ -759,6 +761,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
 xfail('digamma', ''),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition
 xfail('dist', ''),  # aten.dist.default - couldn't find symbolic meta function/decomposition
 xfail('dsplit', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+xfail('dstack', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('einsum', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('expand_as', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('fft.fft2', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -788,6 +791,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
 xfail('gather', ''),  # aten.gather.default - couldn't find symbolic meta function/decomposition
 xfail('gradient', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('hsplit', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+xfail('hstack', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
 xfail('index_copy', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('index_fill', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -938,6 +942,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
 xfail('nn.functional.nll_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('nn.functional.normalize', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('nn.functional.pad', 'circular'),  # Cannot call sizes() on tensor with symbolic sizes/strides
+xfail('nn.functional.pad', 'constant'),  # aten.fill.Scalar - couldn't find symbolic meta function/decom...
 xfail('nn.functional.pad', 'reflect'),  # aten.reflection_pad1d.default - couldn't find symbolic meta fu...
 xfail('nn.functional.pad', 'replicate'),  # aten.replication_pad1d.default - couldn't find symbolic meta...
 xfail('nn.functional.pairwise_distance', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1029,6 +1034,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
 xfail('view_as_complex', ''),  # aten.view_as_complex.default - couldn't find symbolic meta function/deco...
 xfail('view_as', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('vsplit', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
+xfail('vstack', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
 }
 
 def _test_aot_autograd_helper(self, device, dtype, op):

test/test_proxy_tensor.py

+5

@@ -1066,6 +1066,7 @@ def f(a, b, c, d, e):
 xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
 xfail('chunk', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
 xfail('column_stack', ''),  # Tensors of type TensorImpl do not have numel
+xfail('constant_pad_nd', ''),  # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
 xfail('count_nonzero', ''),  # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
 xfail('cross', ''),  # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
 xfail('cummax', ''),  # aten.cummax.default - couldn't find symbolic meta function/decomposition
@@ -1206,6 +1207,7 @@ def f(a, b, c, d, e):
 xfail('nn.functional.interpolate', 'linear'),  # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec...
 xfail('nn.functional.interpolate', 'nearest'),  # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d...
 xfail('nn.functional.interpolate', 'trilinear'),  # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
+xfail('nn.functional.local_response_norm', ''),  # Tensors of type TensorImpl do not have numel
 xfail('nn.functional.margin_ranking_loss', ''),  # The underlying op of 'aten.stride' has no overload name '_schema'
 xfail('nn.functional.max_pool1d', ''),  # Trying to call aten.size on a tensor with symbolic shapes.
 xfail('nn.functional.max_pool3d', ''),  # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
@@ -1215,6 +1217,7 @@ def f(a, b, c, d, e):
 xfail('nn.functional.multi_margin_loss', ''),  # Could not run 'aten::multi_margin_loss' with arguments from the...
 xfail('nn.functional.multilabel_margin_loss', ''),  # Could not run 'aten::multilabel_margin_loss_forward' with ...
 xfail('nn.functional.pad', 'circular'),  # aten.size.default - couldn't find symbolic meta function/decomposition
+xfail('nn.functional.pad', 'constant'),  # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
 xfail('nn.functional.pad', 'reflect'),  # aten.reflection_pad1d.default - couldn't find symbolic meta function/decompo...
 xfail('nn.functional.pad', 'replicate'),  # aten.replication_pad1d.default - couldn't find symbolic meta function/deco...
 xfail('nn.functional.pdist', ''),  # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
@@ -1284,6 +1287,8 @@ def f(a, b, c, d, e):
 xfail('special.scaled_modified_bessel_k1', ''),  # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo...
 xfail('special.xlog1py', ''),  # aten.special_xlog1py.default - couldn't find symbolic meta function/decomposition
 xfail('split', ''),  # 'torch._C.SymIntNode' and 'int'
+xfail('split', 'list_args'),  # aten.size.default - couldn't find symbolic meta function/decomposition
+xfail('split_with_sizes', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
 xfail('stft', ''),  # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
 xfail('sum_to_size', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
 xfail('svd', ''),  # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition

tools/autograd/derivatives.yaml

+1 −1

@@ -354,7 +354,7 @@
   self, other: matmul_backward(grad, self, other, grad_input_mask)
 
 - name: cat(Tensor[] tensors, int dim=0) -> Tensor
-  tensors: cat_tensors_backward(grad, to_args_sizes_symint(tensors), to_args_scalartypes(tensors), dim)
+  tensors: cat_tensors_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors), dim)
   result: cat_jvp(tensors, dim)
 
 - name: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)

tools/autograd/gen_python_functions.py

−5

@@ -1160,11 +1160,6 @@ def is_arg_smaller(t1: Type, t2: Type) -> bool:
 # Prioritize IntArrayRef overload over SymIntArrayRef
 str(t1) == "SymInt[]"
 and str(t2) == "int[]"
-or
-# Make sure both in, SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly
-# converted to either int or SymInt. Prioritize the Tensor overload since it otherwise gets shadowed.
-(str(t1) == "SymInt" or str(t1) == "int")
-and str(t2) == "Tensor"
 )
 
 def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
