diff --git a/csrc/exceptions.cpp b/csrc/exceptions.cpp index da87f01f0af..e9da7f75a44 100644 --- a/csrc/exceptions.cpp +++ b/csrc/exceptions.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 0898f23ec02..b1381941610 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -366,6 +366,57 @@ struct OpRecord : RecordFunctor { std::function fusion_op_; }; +struct SliceOpRecord : RecordFunctor { + SliceOpRecord(std::vector _args, std::vector _outputs) + : RecordFunctor( + std::move(_args), + std::move(_outputs), + "ops.slice", + serde::RecordType::SliceOp) { + arg_names_[1] = "start_indices"; + arg_names_[2] = "end_indices"; + arg_names_[3] = "strides"; + } + ~SliceOpRecord() override = default; + RecordFunctor* clone() final { + return new SliceOpRecord(*this); + } + + void operator()(FusionState& fd) final { + TensorView* arg = fd.getFusionState(args_.at(0).index)->as(); + const std::vector& start = fd.getFusionStateVector(args_.at(1).index); + const std::vector& end = fd.getFusionStateVector(args_.at(2).index); + const std::vector& stride = + fd.getFusionStateVector(args_.at(3).index); + std::vector vec_slice; + for (const auto idx : c10::irange(arg->nDims())) { + // NOTE: there's an extra move, we can use emplace_back if we go write + // some constructors for Slice. + Val* start_idx = start.at(idx); + Val* end_idx = end.at(idx); + Val* stride_idx = stride.at(idx); + NVF_CHECK( + !start_idx->isConstInt() || start_idx->evaluate().as() >= 0, + "Slice operation start_indices must be greater than or equal to 0. Start Indices: ", + start_idx->evaluate().as()); + NVF_CHECK( + !start_idx->isConstInt() || !end_idx->isConstInt() || + end_idx->evaluate().as() >= + start_idx->evaluate().as(), + "Slice operation end_indices must be greater than or equal to start_indices. Start Indices: ", + start_idx->evaluate().as(), + " End Indices: ", + end_idx->evaluate().as()); + NVF_CHECK( + stride_idx->isConstInt() && stride_idx->evaluate().as() == 1, + "nvFuser Limitation: All slice operation strides must be of const size 1."); + vec_slice.push_back({start_idx, end_idx, stride_idx}); + } + auto output = slice(arg, vec_slice); + fd.setFusionState(outputs_.at(0).index, output); + } +}; + struct ReshapeOpRecord : RecordFunctor { ReshapeOpRecord(std::vector _args, std::vector _outputs) : RecordFunctor( @@ -1967,125 +2018,6 @@ struct ScalarRecord : RecordFunctor { PrimDataType dtype_; }; -struct SliceOpRecord : RecordFunctor { - SliceOpRecord( - std::vector _args, - std::vector _outputs, - std::vector start_indices, - std::vector end_indices, - std::vector strides) - : RecordFunctor( - std::move(_args), - std::move(_outputs), - "ops.slice", - serde::RecordType::SliceOp), - start_indices_(std::move(start_indices)), - end_indices_(std::move(end_indices)), - strides_(std::move(strides)) {} - ~SliceOpRecord() override = default; - RecordFunctor* clone() final { - return new SliceOpRecord(*this); - } - - //! Child specific hash function in lower 32 bits. - //! | 31 -------- 20 | 19 -------- 8 | 7 ------ 0 | - //! | start_indices | end_indices | strides | - size_t hash() const final { - auto result = RecordFunctor::hash(); - size_t start_idx_hash = 0; - for (auto i : start_indices_) { - start_idx_hash ^= static_cast(i); - } - size_t end_idx_hash = 0; - for (auto i : end_indices_) { - end_idx_hash ^= static_cast(i); - } - size_t stride_hash = 0; - for (auto i : strides_) { - stride_hash ^= static_cast(i); - } - - result |= (start_idx_hash & 0xfff) << 20; - result |= (end_idx_hash & 0xfff) << 8; - return result | (stride_hash & 0xff); - } - - bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other) && - (start_indices_ == child_ptr->start_indices_) && - (end_indices_ == child_ptr->end_indices_) && - (strides_ == child_ptr->strides_); - } - return result; - } - - void operator()(FusionState& fd) final { - auto arg = fd.getFusionState(args_.at(0).index)->as(); - TensorView* output = slice(arg, start_indices_, end_indices_, strides_); - fd.setFusionState(outputs_.at(0).index, output); - } - - void print(std::ostream& os, bool close_function = true) const final { - RecordFunctor::print(os, false); - os << ", start_indices=["; - bool first_arg = true; - for (auto idx : start_indices_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << idx; - } - os << "], end_indices=["; - first_arg = true; - for (auto idx : end_indices_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << idx; - } - os << "], strides=["; - first_arg = true; - for (auto stride : strides_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << stride; - } - os << "]"; - if (close_function) { - os << ")"; - } - } - - std::pair> recordData( - flatbuffers::FlatBufferBuilder& builder) const final { - return { - serde::RecordData::Slice, - serde::CreateSliceDirect( - builder, &start_indices_, &end_indices_, &strides_) - .Union()}; - } - - private: - //! A slices beginning index for each dimension - //! Values must be greater-than or equal to 0 - std::vector start_indices_; - //! A slices end index for each dimension (excluded from the slice) - //! Values are greater than or equal to the start index for a dimension - std::vector end_indices_; - //! For a dim, the step between start and end. - //! NOTE: Strides are currently limited to steps of 1 - std::vector strides_; -}; - //! Specialized Record Functor for recording FusionDefinition Start. //! There should only ever be one instance of this Record in the //! Fusion Cache. diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 79e7047a54f..744dcdc0dfd 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -60,7 +60,8 @@ template Vector define_vector_fn( FusionDefinition& self, ITERABLE& values, - bool inline_def = false) { + bool inline_def, + bool shape_check) { FUSER_PERF_SCOPE("python_frontend::define_vector_fn"); std::vector args; size_t idx = 0; @@ -68,7 +69,7 @@ Vector define_vector_fn( if (py::isinstance(item)) { auto int_value = py::cast(item); NVF_CHECK( - int_value >= -1, + !shape_check || int_value >= -1, "The value ", int_value, " at index ", @@ -99,11 +100,15 @@ Vector define_vector_explicit_fn( FusionDefinition& self, ITERABLE& values, PrimDataType dtype = DataType::Int) { - return define_vector_fn(self, values, /*inline_def=*/false); + return define_vector_fn( + self, values, /*inline_def=*/false, /*shape_check=*/true); } template -Vector ShapeAsVector(ShapeType shape, FusionDefinition& fd) { +Vector SequenceAsVector( + ShapeType shape, + FusionDefinition& fd, + bool shape_check = true) { static_assert( std::is_same_v || std::is_same_v || @@ -121,7 +126,8 @@ Vector ShapeAsVector(ShapeType shape, FusionDefinition& fd) { // ``` // would not work because the compiler would try to instantiate // define_vector_fn and fail. - return define_vector_fn(fd, shape, /*inline_def=*/true); + return define_vector_fn( + fd, shape, /*inline_def=*/true, /*shape_check=*/shape_check); } } @@ -134,7 +140,7 @@ Tensor broadcast_in_dim_fn( FUSER_PERF_SCOPE("Operators.broadcast_in_dim"); FusionDefinition* fd = op.fusion_definition; NVF_CHECK(op.validUse(), "Attempting to add to a completed definition!"); - Vector output_shape = ShapeAsVector(generic_output_shape, *fd); + Vector output_shape = SequenceAsVector(generic_output_shape, *fd); NVF_CHECK( output_shape.size >= broadcast_dims.size(), "broadcast_dims vector size is too big for output shape!"); @@ -156,7 +162,7 @@ Tensor full_op_fn( PrimDataType dtype) { NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); FusionDefinition* fd = self.fusion_definition; - Vector output_shape = ShapeAsVector(generic_output_shape, *fd); + Vector output_shape = SequenceAsVector(generic_output_shape, *fd); Tensor output = fd->defineTensor(output_shape.size); fd->defineRecord(new FullOpRecord( {fd->recordingState(output_shape()), fd->recordingState(fill_value())}, @@ -173,7 +179,7 @@ Tensor reshape_fn( NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); FusionDefinition* fd = self.fusion_definition; - Vector new_shape = ShapeAsVector(generic_new_shape, *fd); + Vector new_shape = SequenceAsVector(generic_new_shape, *fd); Tensor output = fd->defineTensor(new_shape.size); fd->defineRecord(new ReshapeOpRecord( @@ -200,7 +206,7 @@ Tensor random_dist_op_fn( "Random distributions only create floating point types! ", dtype); FusionDefinition* fd = self.fusion_definition; - Vector new_shape = ShapeAsVector(generic_new_shape, *fd); + Vector new_shape = SequenceAsVector(generic_new_shape, *fd); Tensor output = fd->defineTensor(new_shape.size); std::vector arg_states = { @@ -235,6 +241,78 @@ struct DimInfo { } }; +template +Tensor slice_fn( + FusionDefinition::Operators& self, + Tensor arg, + ShapeType start, + ShapeType end, + std::optional strides) { + NVF_CHECK(self.validUse(), "Attempting to add to a completed definition!"); + + FusionDefinition* fd = self.fusion_definition; + Vector new_start = SequenceAsVector(start, *fd, /*shape_check=*/false); + Vector new_end = SequenceAsVector(end, *fd, /*shape_check=*/false); + size_t stride_index = 0; + + if (strides.has_value()) { + Vector new_stride = + SequenceAsVector(strides.value(), *fd, /*shape_check=*/false); + NVF_CHECK( + new_start.size == new_stride.size, + "Slice start_indices and strides don't match! Start Indices: ", + new_start.size, + " Strides: ", + new_stride.size); + stride_index = new_stride(); + } else { + // set stride with default value; + std::vector stride_vec; + stride_vec.reserve(new_start.size); + // Note: we cannot re-use the same ScalarRecord, otherwise, serialized + // python program uses `define_vector`, which would create multiple + // ScalarRecord, causing a cache miss. + for (auto i : c10::irange(new_start.size)) { + (void)i; // Supress unused variable warning + Scalar out = fd->defineScalar(); + fd->defineRecord(new ScalarRecord( + {fd->recordingState(out())}, + 1, + DataType::Int, + /*inline_def=*/true)); + stride_vec.push_back(out); + } + // Cannot inline definition with `Vector` here, since + // `FusionDefinition.ops.slice` expects start/end/stride to have the same + // type. + Vector default_stride = define_vector_base_fn( + *fd, stride_vec, !std::is_same_v); + stride_index = default_stride(); + } + + NVF_CHECK( + arg.dims == new_start.size, + "Number of tensor dimensions does not match slice dimensions! Tensor-dims: ", + arg.dims, + " Slice-dims: ", + new_start.size); + NVF_CHECK( + new_start.size == new_end.size, + "Slice indexing attribute dimensions don't match! Start Indices: ", + new_start.size, + " End Indices: ", + new_end.size); + + Tensor output = fd->defineTensor(arg.dims); + fd->defineRecord(new SliceOpRecord( + {fd->recordingState(arg()), + fd->recordingState(new_start()), + fd->recordingState(new_end()), + fd->recordingState(stride_index)}, + {fd->recordingState(output())})); + return output; +} + } // namespace std::vector> computeContiguity( @@ -2714,87 +2792,23 @@ void initNvFuserPythonBindings(PyObject* module) { nvf_ops.def( "slice", - [](FusionDefinition::Operators& self, - Tensor arg, - const std::vector& start_indices, - const std::vector& end_indices, - // NOTE: Tried to use std::reference_wrapper to a vector and during - // testing, I was not getting the proper value back. It was like - // like the code was referencing the strides vector that holds the - // default value. - std::optional> opt_strides = - std::nullopt) -> Tensor { - FUSER_PERF_SCOPE("Operators.slice"); - NVF_CHECK( - self.validUse(), "Attempting to add to a completed definition!"); - - std::vector strides; - if (opt_strides.has_value()) { - NVF_CHECK( - start_indices.size() == opt_strides.value().size(), - "Slice start_indices and strides don't match! Start Indices: ", - start_indices.size(), - " Strides: ", - opt_strides.value().size()); - strides.assign( - opt_strides.value().begin(), opt_strides.value().end()); - } else { - strides.resize(start_indices.size(), 1); - } - - NVF_CHECK( - arg.dims == start_indices.size(), - "Number of tensor dimensions does not match slice dimensions! Tensor-dims: ", - arg.dims, - " Slice-dims: ", - start_indices.size()); - NVF_CHECK( - start_indices.size() == end_indices.size(), - "Slice indexing attribute dimensions don't match! Start Indices: ", - start_indices.size(), - " End Indices: ", - end_indices.size(), - " Strides: ", - strides.size()); - for (const auto i : c10::irange(arg.dims)) { - auto start_idx = start_indices[i]; - auto end_idx = end_indices[i]; - auto stride = strides[i]; - NVF_CHECK( - start_idx >= 0, - "Slice operation start_indices must be greater-than-or-equal-to 0. Start Indices: ", - start_indices, - " End Indices: ", - end_indices, - " Strides: ", - strides); - NVF_CHECK( - end_idx >= start_idx, - "Slice operation end_indices must be greater-than-or-equal-to start_indices. Start Indices: ", - start_indices, - " End Indices: ", - end_indices, - " Strides: ", - strides); - NVF_CHECK( - stride == 1, - "nvFuser Limitation: All slice operation strides must be of size 1. Start Indices: ", - start_indices, - " End Indices: ", - end_indices, - " Strides: ", - strides); - } - FusionDefinition* fd = self.fusion_definition; - Tensor output = fd->defineTensor(arg.dims); - fd->defineRecord(new SliceOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(output())}, - start_indices, - end_indices, - strides)); - return output; - }, + slice_fn, + py::arg("arg"), + py::arg("start_indices"), + py::arg("end_indices"), + py::arg("strides") = py::none(), + py::return_value_policy::reference); + nvf_ops.def( + "slice", + slice_fn, + py::arg("arg"), + py::arg("start_indices"), + py::arg("end_indices"), + py::arg("strides") = py::none(), + py::return_value_policy::reference); + nvf_ops.def( + "slice", + slice_fn, py::arg("arg"), py::arg("start_indices"), py::arg("end_indices"), diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index db208650e6f..17ce4a152f0 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -525,13 +525,8 @@ void RecordFunctorFactory::registerAllParsers() { registerParser(RecordType::ReshapeOp, deserializeReshapeRecord); auto deserializeSliceRecord = [](const RecordFunctor* buffer) { - auto data = buffer->data_as_Slice(); return new python_frontend::SliceOpRecord( - parseStateArgs(buffer->args()), - parseStateArgs(buffer->outputs()), - parseVector(data->start_indices()), - parseVector(data->end_indices()), - parseVector(data->strides())); + parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); }; registerParser(RecordType::SliceOp, deserializeSliceRecord); diff --git a/tests/python/opinfo_input_generators.py b/tests/python/opinfo_input_generators.py index e369c03f9e7..fb5a0ea1a20 100644 --- a/tests/python/opinfo_input_generators.py +++ b/tests/python/opinfo_input_generators.py @@ -1153,17 +1153,17 @@ def slice_error_generator( check_start_indices = ErrorSample( {"start_indices": [-1, -2], "end_indices": [5, 5], "strides": [7, 7]}, - "Slice operation start_indices must be greater-than-or-equal-to 0.", + "Slice operation start_indices must be greater than or equal to 0.", ) check_end_indices = ErrorSample( {"start_indices": [3, 4], "end_indices": [1, 2], "strides": [1, 1]}, - "Slice operation end_indices must be greater-than-or-equal-to start_indices.", + "Slice operation end_indices must be greater than or equal to start_indices.", ) check_strides = ErrorSample( {"start_indices": [0, 0], "end_indices": [5, 5], "strides": [5, 5]}, - "nvFuser Limitation: All slice operation strides must be of size 1.", + "nvFuser Limitation: All slice operation strides must be of const size 1.", ) check_tensor_dims = ErrorSample( diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 1bf5824c0bf..f793428ac1c 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -2103,15 +2103,15 @@ def legal(fd: FusionDefinition, acts) -> None: checks = [ ( check_start_indices, - "Slice operation start_indices must be greater-than-or-equal-to 0. .*", + "Slice operation start_indices must be greater than or equal to 0. .*", ), ( check_end_indices, - "Slice operation end_indices must be greater-than-or-equal-to start_indices. .*", + "Slice operation end_indices must be greater than or equal to start_indices. .*", ), ( check_strides, - "nvFuser Limitation: All slice operation strides must be of size 1. .*", + "nvFuser Limitation: All slice operation strides must be of const size 1.*", ), ( check_tensor_dims, @@ -4345,3 +4345,35 @@ def fusion_func(fd: FusionDefinition) -> None: ): with FusionDefinition() as fd: fusion_func(fd) + + def test_slice_api(self): + x = torch.randn((2, 5, 10), dtype=torch.float32, device="cuda:0") + + offset = (0, 1, 2) + + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1, -1], + contiguity=[True, True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[2, 1, 0], + ) + T1 = fd.ops.slice( + T0, start_indices=offset, end_indices=(2, 5, 10), strides=(1, 1, 1) + ) + fd.add_output(T1) + V_start = fd.define_vector(offset) + V_end = T0.shape() + T2 = fd.ops.slice(T0, V_start, V_end) + fd.add_output(T2) + dynamic_start = fd.define_vector(3) + dynamic_end = fd.define_vector(3) + T3 = fd.ops.slice(T0, dynamic_start, dynamic_end) + fd.add_output(T3) + + inputs = [x, *offset, *x.shape] + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + for out in nvf_out: + self.assertTrue(out.allclose(x[:, 1:, 2:])) diff --git a/version.txt b/version.txt index d3b5ba4bfc3..f2722b13396 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.11 +0.2.12