Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slice python api #2932

Merged
merged 51 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
93328b8
WIP
jjsjann123 Sep 11, 2024
77e04da
WIP
jjsjann123 Sep 11, 2024
0dda7a9
WIP
jjsjann123 Sep 11, 2024
2829743
fixing build
jjsjann123 Sep 11, 2024
64cfc8c
fixing build
jjsjann123 Sep 11, 2024
227e2b5
fixing build
jjsjann123 Sep 11, 2024
ae3c2d0
quick fix on kwargs
jjsjann123 Sep 11, 2024
c42bfe1
remove option thing
jjsjann123 Sep 11, 2024
266a4bb
typo
jjsjann123 Sep 11, 2024
9c9ea9a
Merge branch 'main' into slice_python_api
jjsjann123 Sep 16, 2024
43bfaec
CLANGFORMAT
jjsjann123 Sep 16, 2024
2e1fac5
fixing it?!
jjsjann123 Sep 16, 2024
6f011c6
typo
jjsjann123 Sep 16, 2024
ef56b1b
fixing test/build
jjsjann123 Sep 16, 2024
2be6beb
fixing logic
jjsjann123 Sep 16, 2024
69975f1
clangformat
jjsjann123 Sep 16, 2024
53ad605
some error check/message
jjsjann123 Sep 16, 2024
67ed3e4
fixing build, avoiding check
jjsjann123 Sep 16, 2024
4f15739
fixing error message
jjsjann123 Sep 16, 2024
804ec36
quick hack on exception
jjsjann123 Sep 16, 2024
691198d
wip
jjsjann123 Sep 17, 2024
7ee8f26
Merge branch 'main' into slice_python_api
jjsjann123 Sep 17, 2024
e15be6f
clangformat; quick fix on error message
jjsjann123 Sep 17, 2024
128aee5
fixing build
jjsjann123 Sep 17, 2024
a890311
fixing error message
jjsjann123 Sep 17, 2024
e9da45b
removing print
jjsjann123 Sep 17, 2024
19aa2e0
allow fusion cache to cache error message
jjsjann123 Sep 17, 2024
88f81dd
test added
jjsjann123 Sep 17, 2024
d389d0a
black
jjsjann123 Sep 17, 2024
0e72307
revert changes in #2953
jjsjann123 Sep 17, 2024
911d7c1
Merge remote-tracking branch 'origin/fusion_cache_error_cache' into HEAD
jjsjann123 Sep 17, 2024
175c82c
black
jjsjann123 Sep 17, 2024
8870702
Merge branch 'main' into fusion_cache_error_cache
jjsjann123 Sep 18, 2024
cc41d1d
comment message
jjsjann123 Sep 18, 2024
ecc1316
bumping version
jjsjann123 Sep 18, 2024
b872194
fixing deserialization with exception handling
jjsjann123 Sep 18, 2024
35d9fac
fixing typo
jjsjann123 Sep 18, 2024
dc610d4
BLACK
jjsjann123 Sep 18, 2024
0916e1a
Merge remote-tracking branch 'origin/fusion_cache_error_cache' into HEAD
jjsjann123 Sep 18, 2024
f2c511a
Merge remote-tracking branch 'origin/multidevice_typo_fix' into HEAD
jjsjann123 Sep 18, 2024
295b56c
review comments
jjsjann123 Sep 18, 2024
e8b4534
Merge branch 'main' into slice_python_api
jjsjann123 Sep 18, 2024
821a30a
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Sep 19, 2024
cbc819d
errr resolve conflicts
jjsjann123 Sep 19, 2024
e1dd9b0
Merge branch 'main' into slice_python_api
jjsjann123 Sep 19, 2024
eecb024
fixing CI: 1. bump version; 2. update assert exception string
jjsjann123 Sep 19, 2024
cd95c02
fixing missing header
jjsjann123 Sep 19, 2024
1e92624
because I'm a dumb dumb...
jjsjann123 Sep 19, 2024
cb65c77
dumb dumb again
jjsjann123 Sep 20, 2024
88ae47f
error message
jjsjann123 Sep 20, 2024
8ffa26d
fixing test error message
jjsjann123 Sep 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/python_frontend/fusion_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ bool TrieNode::isTerminal() const {
return (record.get()->recordType() == serde::RecordType::End);
}

void TrieNode::markException(std::exception e) {
void TrieNode::setException(const char* e) {
std::lock_guard<std::mutex> guard(trie_node_lock);
exception = e;
}
Expand Down
4 changes: 2 additions & 2 deletions csrc/python_frontend/fusion_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct TrieNode {
// a the end of Fusion entry in the cache.
bool isTerminal() const;
std::optional<std::exception> getException();
void markException(std::exception e);
void setException(const char* e);
//! Serialize TrieNode using flatbuffers
NVF_API flatbuffers::Offset<serde::TrieNode> serialize(
flatbuffers::FlatBufferBuilder& builder,
Expand All @@ -127,7 +127,7 @@ struct TrieNode {
TrieNode* parent;
//! For thread-Safe locking of a node
std::mutex trie_node_lock;
std::optional<std::exception> exception;
std::optional<std::string> exception = std::nullopt;
};

//! \class FusionCache
Expand Down
3 changes: 1 addition & 2 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ void FusionDefinition::finalizeDefinition() {

buildFusionIr(preschedFusion());
} catch (const std::exception& e) {
trie_node_->markException(e);
trie_node_->setException(e.what());
fusion_id_ = std::nullopt;
std::cout << "error: " << e.what() << std::endl;
throw;
}

Expand Down
28 changes: 15 additions & 13 deletions csrc/python_frontend/fusion_record.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,19 +395,21 @@ struct SliceOpRecord : RecordFunctor {
Val* start_idx = start.at(idx);
Val* end_idx = end.at(idx);
Val* stride_idx = stride.at(idx);
NVF_CHECK(
!start_idx->isConstInt() || start_idx->evaluate().as<int64_t>() >= 0,
"Slice operation start_indices must be greater-than-or-equal-to 0. Start Indices: ",
start_idx->evaluate().as<int64_t>());
NVF_CHECK(
!start_idx->isConstInt() || !end_idx->isConstInt() || end_idx->evaluate().as<int64_t>() >= start_idx->evaluate().as<int64_t>(),
"Slice operation end_indices must be greater-than-or-equal-to start_indices. Start Indices: ",
start_idx->evaluate().as<int64_t>(),
" End Indices: ",
end_idx->evaluate().as<int64_t>());
NVF_CHECK(
stride_idx->isConstInt() && stride_idx->evaluate().as<int64_t>() == 1,
"nvFuser Limitation: All slice operation strides must be of const size 1");
NVF_CHECK(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These checks are now inside finalizeDefinition instead of on the constructor. Hence exposes the issue on our FusionCache #2953

!start_idx->isConstInt() || start_idx->evaluate().as<int64_t>() >= 0,
"Slice operation start_indices must be greater-than-or-equal-to 0. Start Indices: ",
jacobhinkle marked this conversation as resolved.
Show resolved Hide resolved
start_idx->evaluate().as<int64_t>());
NVF_CHECK(
!start_idx->isConstInt() || !end_idx->isConstInt() ||
end_idx->evaluate().as<int64_t>() >=
start_idx->evaluate().as<int64_t>(),
"Slice operation end_indices must be greater-than-or-equal-to start_indices. Start Indices: ",
start_idx->evaluate().as<int64_t>(),
" End Indices: ",
end_idx->evaluate().as<int64_t>());
NVF_CHECK(
stride_idx->isConstInt() && stride_idx->evaluate().as<int64_t>() == 1,
"nvFuser Limitation: All slice operation strides must be of const size 1");
vec_slice.push_back({start_idx, end_idx, stride_idx});
}
auto output = slice(arg, vec_slice);
Expand Down
14 changes: 10 additions & 4 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,15 @@ Vector define_vector_explicit_fn(
FusionDefinition& self,
ITERABLE& values,
PrimDataType dtype = DataType::Int) {
return define_vector_fn<ITERABLE>(self, values, /*inline_def=*/false, /*shape_check=*/true);
return define_vector_fn<ITERABLE>(
self, values, /*inline_def=*/false, /*shape_check=*/true);
}

template <class ShapeType>
Vector SequenceAsVector(ShapeType shape, FusionDefinition& fd, bool shape_check=true) {
Vector SequenceAsVector(
ShapeType shape,
FusionDefinition& fd,
bool shape_check = true) {
static_assert(
std::is_same_v<ShapeType, Vector> ||
std::is_same_v<ShapeType, py::list> ||
Expand All @@ -122,7 +126,8 @@ Vector SequenceAsVector(ShapeType shape, FusionDefinition& fd, bool shape_check=
// ```
// would not work because the compiler would try to instantiate
// define_vector_fn<Vector> and fail.
return define_vector_fn<ShapeType>(fd, shape, /*inline_def=*/true, /*shape_check=*/shape_check);
return define_vector_fn<ShapeType>(
fd, shape, /*inline_def=*/true, /*shape_check=*/shape_check);
}
}

Expand Down Expand Up @@ -251,7 +256,8 @@ Tensor slice_fn(
size_t stride_index = 0;

if (strides.has_value()) {
Vector new_stride = SequenceAsVector(strides.value(), *fd, /*shape_check=*/false);
Vector new_stride =
SequenceAsVector(strides.value(), *fd, /*shape_check=*/false);
NVF_CHECK(
new_start.size == new_stride.size,
"Slice start_indices and strides don't match! Start Indices: ",
Expand Down
Loading