Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slice python api #2932

Open
wants to merge 51 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
93328b8
WIP
jjsjann123 Sep 11, 2024
77e04da
WIP
jjsjann123 Sep 11, 2024
0dda7a9
WIP
jjsjann123 Sep 11, 2024
2829743
fixing build
jjsjann123 Sep 11, 2024
64cfc8c
fixing build
jjsjann123 Sep 11, 2024
227e2b5
fixing build
jjsjann123 Sep 11, 2024
ae3c2d0
quick fix on kwargs
jjsjann123 Sep 11, 2024
c42bfe1
remove option thing
jjsjann123 Sep 11, 2024
266a4bb
typo
jjsjann123 Sep 11, 2024
9c9ea9a
Merge branch 'main' into slice_python_api
jjsjann123 Sep 16, 2024
43bfaec
CLANGFORMAT
jjsjann123 Sep 16, 2024
2e1fac5
fixing it?!
jjsjann123 Sep 16, 2024
6f011c6
typo
jjsjann123 Sep 16, 2024
ef56b1b
fixing test/build
jjsjann123 Sep 16, 2024
2be6beb
fixing logic
jjsjann123 Sep 16, 2024
69975f1
clangformat
jjsjann123 Sep 16, 2024
53ad605
some error check/message
jjsjann123 Sep 16, 2024
67ed3e4
fixing build, avoiding check
jjsjann123 Sep 16, 2024
4f15739
fixing error message
jjsjann123 Sep 16, 2024
804ec36
quick hack on exception
jjsjann123 Sep 16, 2024
691198d
wip
jjsjann123 Sep 17, 2024
7ee8f26
Merge branch 'main' into slice_python_api
jjsjann123 Sep 17, 2024
e15be6f
clangformat; quick fix on error message
jjsjann123 Sep 17, 2024
128aee5
fixing build
jjsjann123 Sep 17, 2024
a890311
fixing error message
jjsjann123 Sep 17, 2024
e9da45b
removing print
jjsjann123 Sep 17, 2024
19aa2e0
allow fusion cache to cache error message
jjsjann123 Sep 17, 2024
88f81dd
test added
jjsjann123 Sep 17, 2024
d389d0a
black
jjsjann123 Sep 17, 2024
0e72307
revert changes in #2953
jjsjann123 Sep 17, 2024
911d7c1
Merge remote-tracking branch 'origin/fusion_cache_error_cache' into HEAD
jjsjann123 Sep 17, 2024
175c82c
black
jjsjann123 Sep 17, 2024
8870702
Merge branch 'main' into fusion_cache_error_cache
jjsjann123 Sep 18, 2024
cc41d1d
comment message
jjsjann123 Sep 18, 2024
ecc1316
bumping version
jjsjann123 Sep 18, 2024
b872194
fixing deserialization with exception handling
jjsjann123 Sep 18, 2024
35d9fac
fixing typo
jjsjann123 Sep 18, 2024
dc610d4
BLACK
jjsjann123 Sep 18, 2024
0916e1a
Merge remote-tracking branch 'origin/fusion_cache_error_cache' into HEAD
jjsjann123 Sep 18, 2024
f2c511a
Merge remote-tracking branch 'origin/multidevice_typo_fix' into HEAD
jjsjann123 Sep 18, 2024
295b56c
review comments
jjsjann123 Sep 18, 2024
e8b4534
Merge branch 'main' into slice_python_api
jjsjann123 Sep 18, 2024
821a30a
Merge remote-tracking branch 'origin/main' into HEAD
jjsjann123 Sep 19, 2024
cbc819d
errr resolve conflicts
jjsjann123 Sep 19, 2024
e1dd9b0
Merge branch 'main' into slice_python_api
jjsjann123 Sep 19, 2024
eecb024
fixing CI: 1. bump version; 2. update assert exception string
jjsjann123 Sep 19, 2024
cd95c02
fixing missing header
jjsjann123 Sep 19, 2024
1e92624
because I'm a dumb dumb...
jjsjann123 Sep 19, 2024
cb65c77
dumb dumb again
jjsjann123 Sep 20, 2024
88ae47f
error message
jjsjann123 Sep 20, 2024
8ffa26d
fixing test error message
jjsjann123 Sep 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions csrc/python_frontend/fusion_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,16 @@ bool TrieNode::isTerminal() const {
return (record.get()->recordType() == serde::RecordType::End);
}

//! Records the message of an exception on this node.
//!
//! The node's mutex is held while the message is stored, so concurrent
//! callers of setException()/getException() on the same node do not race on
//! the `exception` member.
//!
//! \param e Null-terminated message text (callers typically pass
//!          `std::exception::what()`). A null pointer is stored as an empty
//!          message instead of being handed to std::string, whose
//!          `const char*` constructor has undefined behavior for null.
void TrieNode::setException(const char* e) {
  std::lock_guard<std::mutex> guard(trie_node_lock);
  // Keep the optional engaged even when no text is available: an engaged
  // value is what marks this node as having failed.
  exception = (e != nullptr) ? std::string(e) : std::string();
}

//! Returns a copy of the exception message stored on this node, or
//! std::nullopt when none has been recorded.
//!
//! The copy is taken while holding the node's mutex, so the caller receives
//! a value that is independent of later setException() calls.
std::optional<std::string> TrieNode::getException() {
  const std::scoped_lock lock{trie_node_lock};
  std::optional<std::string> snapshot = exception;
  return snapshot;
}

flatbuffers::Offset<serde::TrieNode> TrieNode::serialize(
flatbuffers::FlatBufferBuilder& builder,
const std::map<RecordFunctor*, size_t>&
Expand Down
3 changes: 3 additions & 0 deletions csrc/python_frontend/fusion_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ struct TrieNode {
// Queries whether the entry denotes a leaf node which also represents
// the end of a Fusion entry in the cache.
bool isTerminal() const;
std::optional<std::string> getException();
void setException(const char* e);
//! Serialize TrieNode using flatbuffers
NVF_API flatbuffers::Offset<serde::TrieNode> serialize(
flatbuffers::FlatBufferBuilder& builder,
Expand All @@ -125,6 +127,7 @@ struct TrieNode {
TrieNode* parent;
//! For thread-Safe locking of a node
std::mutex trie_node_lock;
std::optional<std::string> exception = std::nullopt;
};

//! \class FusionCache
Expand Down
18 changes: 13 additions & 5 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,19 @@ void FusionDefinition::finalizeDefinition() {
}
trie_node_ = fusionCache()->createChild(trie_node_, end_record_.get());
fusion_id_ = std::optional<size_t>(trie_node_->fusion_id);
NVF_CHECK(id().has_value(), "Invalid fusion id!");
try {
NVF_CHECK(id().has_value(), "Invalid fusion id!");

if (isDebugDumpEnabled(DebugDumpOption::PythonDefinition)) {
print(debug());
}
if (isDebugDumpEnabled(DebugDumpOption::PythonDefinition)) {
print(debug());
}

buildFusionIr(preschedFusion());
buildFusionIr(preschedFusion());
} catch (const std::exception& e) {
trie_node_->setException(e.what());
fusion_id_ = std::nullopt;
throw;
}

if (isDebugDumpEnabled(DebugDumpOption::FusionIrOriginal)) {
printIr();
Expand All @@ -103,6 +109,8 @@ void FusionDefinition::finalizeDefinition() {
debug() << "\nFusionDefinition: Terminal Node found!\n";
}
trie_node_ = child_node.value();
std::optional<std::string> opt_e = trie_node_->getException();
NVF_CHECK(!opt_e.has_value(), opt_e.value());
fusion_id_ = std::optional<size_t>(trie_node_->fusion_id);
}

Expand Down
170 changes: 51 additions & 119 deletions csrc/python_frontend/fusion_record.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,57 @@ struct OpRecord : RecordFunctor {
std::function<OutType(ArgTypes...)> fusion_op_;
};

//! Record for the "ops.slice" operation whose start/end/stride arguments are
//! taken from the fusion state as vectors of Val* (args 1..3) rather than
//! being stored on the record itself. Argument 0 is the tensor to slice and
//! output 0 receives the sliced tensor.
struct SliceOpRecord : RecordFunctor {
// Forwards args/outputs to RecordFunctor and names arguments 1..3.
// NOTE(review): arg_names_ is declared outside this view; presumably it
// holds the keyword names used when the record is printed — confirm.
SliceOpRecord(std::vector<State> _args, std::vector<State> _outputs)
: RecordFunctor(
std::move(_args),
std::move(_outputs),
"ops.slice",
serde::RecordType::SliceOp) {
arg_names_[1] = "start_indices";
arg_names_[2] = "end_indices";
arg_names_[3] = "strides";
}
~SliceOpRecord() override = default;
// Returns a heap-allocated copy of this record; the caller receives a raw
// pointer.
RecordFunctor* clone() final {
return new SliceOpRecord(*this);
}

// Builds the slice in the fusion state `fd`:
//   1. reads the input tensor (arg 0) and the start/end/stride Val* vectors
//      (args 1..3) from `fd`;
//   2. for every dimension of the input, validates the triple and appends a
//      Slice to vec_slice;
//   3. calls slice() and stores the result as output 0.
// The `.at(idx)` lookups throw std::out_of_range if any of the three vectors
// has fewer entries than the input has dimensions.
void operator()(FusionState& fd) final {
TensorView* arg = fd.getFusionState(args_.at(0).index)->as<TensorView>();
const std::vector<Val*>& start = fd.getFusionStateVector(args_.at(1).index);
const std::vector<Val*>& end = fd.getFusionStateVector(args_.at(2).index);
const std::vector<Val*>& stride =
fd.getFusionStateVector(args_.at(3).index);
std::vector<Slice> vec_slice;
for (const auto idx : c10::irange(arg->nDims())) {
// NOTE: there's an extra move, we can use emplace_back if we go write
// some constructors for Slice.
jacobhinkle marked this conversation as resolved.
Show resolved Hide resolved
Val* start_idx = start.at(idx);
Val* end_idx = end.at(idx);
Val* stride_idx = stride.at(idx);
// Start check: only enforced when the start is a constant integer; a
// non-constant start passes this check unconditionally.
NVF_CHECK(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These checks are now inside finalizeDefinition instead of on the constructor. Hence exposes the issue on our FusionCache #2953

!start_idx->isConstInt() || start_idx->evaluate().as<int64_t>() >= 0,
"Slice operation start_indices must be greater-than-or-equal-to 0. Start Indices: ",
jacobhinkle marked this conversation as resolved.
Show resolved Hide resolved
start_idx->evaluate().as<int64_t>());
// Ordering check: only enforced when both start and end are constant
// integers; if either is non-constant the check passes.
NVF_CHECK(
!start_idx->isConstInt() || !end_idx->isConstInt() ||
end_idx->evaluate().as<int64_t>() >=
start_idx->evaluate().as<int64_t>(),
"Slice operation end_indices must be greater-than-or-equal-to start_indices. Start Indices: ",
start_idx->evaluate().as<int64_t>(),
" End Indices: ",
end_idx->evaluate().as<int64_t>());
// Stride check: unlike the two checks above, a non-constant stride is
// rejected — the stride must be a constant integer equal to 1.
NVF_CHECK(
stride_idx->isConstInt() && stride_idx->evaluate().as<int64_t>() == 1,
"nvFuser Limitation: All slice operation strides must be of const size 1");
vec_slice.push_back({start_idx, end_idx, stride_idx});
}
auto output = slice(arg, vec_slice);
fd.setFusionState(outputs_.at(0).index, output);
}
};

struct ReshapeOpRecord : RecordFunctor {
ReshapeOpRecord(std::vector<State> _args, std::vector<State> _outputs)
: RecordFunctor(
Expand Down Expand Up @@ -1969,125 +2020,6 @@ struct ScalarRecord : RecordFunctor {
PrimDataType dtype_;
};

struct SliceOpRecord : RecordFunctor {
SliceOpRecord(
std::vector<State> _args,
std::vector<State> _outputs,
std::vector<int64_t> start_indices,
std::vector<int64_t> end_indices,
std::vector<int64_t> strides)
: RecordFunctor(
std::move(_args),
std::move(_outputs),
"ops.slice",
serde::RecordType::SliceOp),
start_indices_(std::move(start_indices)),
end_indices_(std::move(end_indices)),
strides_(std::move(strides)) {}
~SliceOpRecord() override = default;
RecordFunctor* clone() final {
return new SliceOpRecord(*this);
}

//! Child specific hash function in lower 32 bits.
//! | 31 -------- 20 | 19 -------- 8 | 7 ------ 0 |
//! | start_indices | end_indices | strides |
size_t hash() const final {
auto result = RecordFunctor::hash();
size_t start_idx_hash = 0;
for (auto i : start_indices_) {
start_idx_hash ^= static_cast<size_t>(i);
}
size_t end_idx_hash = 0;
for (auto i : end_indices_) {
end_idx_hash ^= static_cast<size_t>(i);
}
size_t stride_hash = 0;
for (auto i : strides_) {
stride_hash ^= static_cast<size_t>(i);
}

result |= (start_idx_hash & 0xfff) << 20;
result |= (end_idx_hash & 0xfff) << 8;
return result | (stride_hash & 0xff);
}

bool operator==(const RecordFunctor& other) const final {
auto result = false;
if (auto child_ptr = dynamic_cast<const SliceOpRecord*>(&other)) {
result = RecordFunctor::operator==(other) &&
(start_indices_ == child_ptr->start_indices_) &&
(end_indices_ == child_ptr->end_indices_) &&
(strides_ == child_ptr->strides_);
}
return result;
}

void operator()(FusionState& fd) final {
auto arg = fd.getFusionState(args_.at(0).index)->as<TensorView>();
TensorView* output = slice(arg, start_indices_, end_indices_, strides_);
fd.setFusionState(outputs_.at(0).index, output);
}

void print(std::ostream& os, bool close_function = true) const final {
RecordFunctor::print(os, false);
os << ", start_indices=[";
bool first_arg = true;
for (auto idx : start_indices_) {
if (first_arg) {
first_arg = false;
} else {
os << ", ";
}
os << idx;
}
os << "], end_indices=[";
first_arg = true;
for (auto idx : end_indices_) {
if (first_arg) {
first_arg = false;
} else {
os << ", ";
}
os << idx;
}
os << "], strides=[";
first_arg = true;
for (auto stride : strides_) {
if (first_arg) {
first_arg = false;
} else {
os << ", ";
}
os << stride;
}
os << "]";
if (close_function) {
os << ")";
}
}

std::pair<serde::RecordData, flatbuffers::Offset<void>> recordData(
flatbuffers::FlatBufferBuilder& builder) const final {
return {
serde::RecordData::Slice,
serde::CreateSliceDirect(
builder, &start_indices_, &end_indices_, &strides_)
.Union()};
}

private:
//! A slices beginning index for each dimension
//! Values must be greater-than or equal to 0
std::vector<int64_t> start_indices_;
//! A slices end index for each dimension (excluded from the slice)
//! Values are greater than or equal to the start index for a dimension
std::vector<int64_t> end_indices_;
//! For a dim, the step between start and end.
//! NOTE: Strides are currently limited to steps of 1
std::vector<int64_t> strides_;
};

//! Specialized Record Functor for recording FusionDefinition Start.
//! There should only ever be one instance of this Record in the
//! Fusion Cache.
Expand Down
Loading
Loading