Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thresh_as_int parameter added #42

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tl2cgen/compiler_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct CompilerParam {
compilation time and reduce memory consumption during
compilation. */
int parallel_comp{0};
/*! \brief Whether to interpret threshold points as integers (0: no, >0: yes) */
bool thresh_as_int{false};
/*! \brief If >0, produce extra messages */
int verbose{0};
/*! \brief Native lib name (without extension) */
Expand Down
4 changes: 3 additions & 1 deletion include/tl2cgen/detail/compiler/ast/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,17 @@ class NumericalConditionNode : public ConditionNode {
public:
using ThresholdVariantT = std::variant<float, double>;
NumericalConditionNode(std::uint32_t split_index, bool default_left, treelite::Operator op,
ThresholdVariantT threshold, std::optional<int> quantized_threshold)
ThresholdVariantT threshold, std::optional<int> quantized_threshold, bool thresh_as_int)
: ConditionNode(split_index, default_left),
op_(op),
threshold_(threshold),
quantized_threshold_(quantized_threshold),
thresh_as_int_(thresh_as_int),
zero_quantized_(-1) {}
treelite::Operator op_;
ThresholdVariantT threshold_;
std::optional<int> quantized_threshold_;
bool thresh_as_int_;
int zero_quantized_; // quantized value of 0.0f (useful when convert_missing_to_zero is set)
std::string GetDump() const override;
};
Expand Down
4 changes: 2 additions & 2 deletions include/tl2cgen/detail/compiler/ast/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ASTBuilder {
ASTBuilder() : main_node_(nullptr) {}

/* \brief Initially build AST from model */
void BuildAST(treelite::Model const& model);
void BuildAST(treelite::Model const& model, bool thresh_as_int);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the new option thresh_as_int is fine for now. But if we were to add more options to BuildAST, we should consider passing around the CompilerParam struct instead, to avoid having too many function arguments.

/* \brief Generate is_categorical[] array, which tells whether each feature
is categorical or numerical */
void GenerateIsCategoricalArray();
Expand Down Expand Up @@ -69,7 +69,7 @@ class ASTBuilder {
template <typename ThresholdType, typename LeafOutputType>
ASTNode* BuildASTFromTree(ASTNode* parent,
treelite::Tree<ThresholdType, LeafOutputType> const& tree, int tree_id,
std::int32_t target_id, std::int32_t class_id, int nid);
std::int32_t target_id, std::int32_t class_id, int nid, bool thresh_as_int);

// Keep track of all nodes built so far, to prevent memory leak
std::vector<std::unique_ptr<ASTNode>> nodes_;
Expand Down
16 changes: 8 additions & 8 deletions src/compiler/ast/build.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ std::optional<std::vector<std::int32_t>> ComputeAverageFactor(treelite::Model co

namespace tl2cgen::compiler::detail::ast {

void ASTBuilder::BuildAST(treelite::Model const& model) {
void ASTBuilder::BuildAST(treelite::Model const& model, bool thresh_as_int) {
main_node_ = AddNode<MainNode>(
nullptr, model.base_scores.AsVector(), ComputeAverageFactor(model), model.postprocessor);
meta_.num_target_ = model.num_target;
Expand All @@ -72,7 +72,7 @@ void ASTBuilder::BuildAST(treelite::Model const& model) {
[&](auto&& model_preset) {
for (std::size_t tree_id = 0; tree_id < model_preset.trees.size(); ++tree_id) {
ASTNode* tree_head = BuildASTFromTree(func, model_preset.trees[tree_id],
static_cast<int>(tree_id), model.target_id[tree_id], model.class_id[tree_id], 0);
static_cast<int>(tree_id), model.target_id[tree_id], model.class_id[tree_id], 0, thresh_as_int);
func->children_.push_back(tree_head);
}
using ModelPresetT = std::remove_const_t<std::remove_reference_t<decltype(model_preset)>>;
Expand All @@ -97,7 +97,7 @@ void ASTBuilder::BuildAST(treelite::Model const& model) {
template <typename ThresholdType, typename LeafOutputType>
ASTNode* ASTBuilder::BuildASTFromTree(ASTNode* parent,
treelite::Tree<ThresholdType, LeafOutputType> const& tree, int tree_id, std::int32_t target_id,
std::int32_t class_id, int nid) {
std::int32_t class_id, int nid, bool thresh_as_int) {
ASTNode* ast_node = nullptr;
if (tree.IsLeaf(nid)) {
if (meta_.leaf_vector_shape_[0] == 1 && meta_.leaf_vector_shape_[1] == 1) {
Expand All @@ -109,7 +109,7 @@ ASTNode* ASTBuilder::BuildASTFromTree(ASTNode* parent,
} else {
if (tree.NodeType(nid) == treelite::TreeNodeType::kNumericalTestNode) {
ast_node = AddNode<NumericalConditionNode>(parent, tree.SplitIndex(nid),
tree.DefaultLeft(nid), tree.ComparisonOp(nid), tree.Threshold(nid), std::nullopt);
tree.DefaultLeft(nid), tree.ComparisonOp(nid), tree.Threshold(nid), std::nullopt, thresh_as_int);
} else {
ast_node = AddNode<CategoricalConditionNode>(parent, tree.SplitIndex(nid),
tree.DefaultLeft(nid), tree.CategoryList(nid), tree.CategoryListRightChild(nid));
Expand All @@ -118,9 +118,9 @@ ASTNode* ASTBuilder::BuildASTFromTree(ASTNode* parent,
dynamic_cast<ConditionNode*>(ast_node)->gain_ = tree.Gain(nid);
}
ast_node->children_.push_back(
BuildASTFromTree(ast_node, tree, tree_id, target_id, class_id, tree.LeftChild(nid)));
BuildASTFromTree(ast_node, tree, tree_id, target_id, class_id, tree.LeftChild(nid), thresh_as_int));
ast_node->children_.push_back(
BuildASTFromTree(ast_node, tree, tree_id, target_id, class_id, tree.RightChild(nid)));
BuildASTFromTree(ast_node, tree, tree_id, target_id, class_id, tree.RightChild(nid), thresh_as_int));
}
ast_node->node_id_ = nid;
ast_node->tree_id_ = tree_id;
Expand All @@ -135,8 +135,8 @@ ASTNode* ASTBuilder::BuildASTFromTree(ASTNode* parent,
}

template ASTNode* ASTBuilder::BuildASTFromTree(
ASTNode*, treelite::Tree<float, float> const&, int, std::int32_t, std::int32_t, int);
ASTNode*, treelite::Tree<float, float> const&, int, std::int32_t, std::int32_t, int, bool);
template ASTNode* ASTBuilder::BuildASTFromTree(
ASTNode*, treelite::Tree<double, double> const&, int, std::int32_t, std::int32_t, int);
ASTNode*, treelite::Tree<double, double> const&, int, std::int32_t, std::int32_t, int, bool);

} // namespace tl2cgen::compiler::detail::ast
48 changes: 44 additions & 4 deletions src/compiler/codegen/condition_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,55 @@ std::string GetFabsCFunc(std::string const& threshold_type) {
}
}

/*!
 * \brief Render the raw bit pattern of a float as a hexadecimal literal
 * \param number Value whose object representation is dumped
 * \return "0x" followed by the bits of \p number in lowercase hex, zero-padded to 8 digits
 */
std::string float_to_bin(float number) {
  static_assert(sizeof(unsigned int) == sizeof(float),
      "float_to_bin requires unsigned int and float to have the same size");
  // Copy the object representation one byte at a time. Reading a float through an
  // unsigned int pointer breaks the strict-aliasing rule (undefined behavior), whereas
  // access through unsigned char is always permitted.
  unsigned int bits = 0;
  unsigned char const* src = reinterpret_cast<unsigned char const*>(&number);
  unsigned char* dst = reinterpret_cast<unsigned char*>(&bits);
  for (unsigned int i = 0; i < sizeof(bits); ++i) {
    dst[i] = src[i];
  }
  std::stringstream ss;
  ss << "0x" << std::hex << std::setw(8) << std::setfill('0') << bits;
  return ss.str();
}

/*!
 * \brief Render the raw bit pattern of a double as a hexadecimal literal
 * \param number Value whose object representation is dumped
 * \return "0x" followed by the bits of \p number in lowercase hex, zero-padded to 16 digits
 */
std::string double_to_bin(double number) {
  static_assert(sizeof(unsigned long long) == sizeof(double),
      "double_to_bin requires unsigned long long and double to have the same size");
  // Copy the object representation one byte at a time. Reading a double through an
  // unsigned long long pointer breaks the strict-aliasing rule (undefined behavior),
  // whereas access through unsigned char is always permitted.
  unsigned long long bits = 0;
  unsigned char const* src = reinterpret_cast<unsigned char const*>(&number);
  unsigned char* dst = reinterpret_cast<unsigned char*>(&bits);
  for (unsigned int i = 0; i < sizeof(bits); ++i) {
    dst[i] = src[i];
  }
  std::stringstream ss;
  ss << "0x" << std::hex << std::setw(16) << std::setfill('0') << bits;
  return ss.str();
}

/*!
 * \brief Map a relational operator to its logical complement
 *
 * "<" maps to ">=", ">" to "<=", "<=" to ">" and ">=" to "<". Note that this is the
 * negation of the comparison, not the mirrored operator obtained by swapping operands.
 *
 * \param op Textual form of the operator
 * \return Complementary operator, or an empty string if \p op is not one of the four above
 */
std::string getOppositeOperator(const std::string& op) {
  struct OperatorPair {
    char const* original;
    char const* complement;
  };
  static OperatorPair const kComplements[] = {
      {"<", ">="},
      {">", "<="},
      {"<=", ">"},
      {">=", "<"},
  };
  for (OperatorPair const& entry : kComplements) {
    if (op == entry.original) {
      return entry.complement;
    }
  }
  return "";  // unknown operator
}

std::string thresh_as_int(const std::string& threshold_type, ast::NumericalConditionNode const* node) {
std::string negatstring = "";
float splitval = std::get<float>(node->threshold_);
std::string op = treelite::OperatorToString(node->op_);
std::string new_dtype = (threshold_type == "double") ? "long long" : "int";

if (splitval < 0) {
splitval = -splitval;
negatstring = " ^ (0b1 << 31)";
op = getOppositeOperator(op);
}

std::string split_val_bin = float_to_bin(splitval);
return "(*( (("+new_dtype+"*)(data)) + "+std::to_string(node->split_index_)+" )"
+negatstring+")"+op+"(("+new_dtype+")("+split_val_bin+"))";
Comment on lines +70 to +71
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use fmt::format to format strings for the sake of legibility. It works like f-strings in Python.

Suggested change
return "(*( (("+new_dtype+"*)(data)) + "+std::to_string(node->split_index_)+" )"
+negatstring+")"+op+"(("+new_dtype+")("+split_val_bin+"))";
return fmt::format("(*( (({new_dtype}*)(data)) + {split_index} ) {negatstring} ) {op} (( {new_dtype} )( {split_val_bin} ))",
"new_dtype"_a = new_dtype, "split_index"_a = node->split_index_,
"negatstring"_a = negatstring, "op"_a = op,
"new_dtype"_a = new_dtype, "split_val_bin"_a = split_val_bin);

}

inline std::string ExtractNumericalCondition(ast::NumericalConditionNode const* node) {
std::string const threshold_type = codegen::GetThresholdCType(node);
std::string result;
if (node->quantized_threshold_) { // Quantized threshold
std::string lhs
= fmt::format("data[{split_index}].qvalue", "split_index"_a = node->split_index_);
std::string lhs = fmt::format("data[{split_index}].qvalue", "split_index"_a = node->split_index_);
result = fmt::format("{lhs} {opname} {threshold}", "lhs"_a = lhs,
"opname"_a = treelite::OperatorToString(node->op_),
"threshold"_a = *node->quantized_threshold_);
"opname"_a = treelite::OperatorToString(node->op_), "threshold"_a = *node->quantized_threshold_);
} else if (node->thresh_as_int_) { // Threshold as integer
if (!(threshold_type == "float" || threshold_type == "double")) { // Only float and double are supported
throw std::runtime_error("Invalid threshold type.");
}
result = thresh_as_int(threshold_type, node);
} else {
result = std::visit(
[&](auto&& threshold) -> std::string {
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ detail::ast::ASTBuilder LowerToAST(
treelite::Model const& model, tl2cgen::compiler::CompilerParam const& param) {
/* 1. Lower the tree ensemble model into Abstract Syntax Tree (AST) */
detail::ast::ASTBuilder builder;
builder.BuildAST(model);
builder.BuildAST(model, param.thresh_as_int);

/* 2. Apply optimization passes to AST */
if (param.annotate_in != "NULL") {
Expand Down
4 changes: 4 additions & 0 deletions src/compiler/compiler_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ CompilerParam CompilerParam::ParseFromJSON(char const* param_json_str) {
TL2CGEN_CHECK(e.value.IsInt()) << "Expected an integer for 'quantize'";
param.quantize = e.value.GetInt();
TL2CGEN_CHECK_GE(param.quantize, 0) << "'quantize' must be 0 or greater";
} else if (key == "thresh_as_int") {
TL2CGEN_CHECK(e.value.IsInt()) << "Expected an integer for 'thresh_as_int'";
param.thresh_as_int = e.value.GetInt();
TL2CGEN_CHECK_GE(param.quantize, 0) << "'thresh_as_int' must be 0 or greater";
} else if (key == "parallel_comp") {
TL2CGEN_CHECK(e.value.IsInt()) << "Expected an integer for 'parallel_comp'";
param.parallel_comp = e.value.GetInt();
Expand Down
Loading