From a002e38d3603a9146155040dde70fbdb5cbe22e7 Mon Sep 17 00:00:00 2001 From: AntoinePrv Date: Wed, 28 Jul 2021 14:59:32 -0400 Subject: [PATCH] Code cleanup --- .../src/observation/test-nodebipartite.cpp | 12 ++++------- python/src/ecole/core/observation.cpp | 20 +++++++++---------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/libecole/tests/src/observation/test-nodebipartite.cpp b/libecole/tests/src/observation/test-nodebipartite.cpp index 20d115c8..d02ef310 100644 --- a/libecole/tests/src/observation/test-nodebipartite.cpp +++ b/libecole/tests/src/observation/test-nodebipartite.cpp @@ -44,16 +44,12 @@ TEST_CASE("NodeBipartite return correct observation", "[obs]") { } SECTION("Variable features are not all nan") { - auto const& var_feat = optional_obs.value().variable_features; - for (std::size_t i = 0; i < var_feat.shape()[1]; ++i) { - REQUIRE_FALSE(xt::all(xt::isnan(xt::col(var_feat, static_cast(i))))); - } + auto const& obs = optional_obs.value(); + REQUIRE_FALSE(xt::all(xt::isnan(obs.variable_features))); } SECTION("Row features are not all nan") { - auto const& row_feat = optional_obs.value().row_features; - for (std::size_t i = 0; i < row_feat.shape()[1]; ++i) { - REQUIRE_FALSE(xt::all(xt::isnan(xt::col(row_feat, static_cast(i))))); - } + auto const& obs = optional_obs.value(); + REQUIRE_FALSE(xt::all(xt::isnan(obs.row_features))); } } diff --git a/python/src/ecole/core/observation.cpp b/python/src/ecole/core/observation.cpp index 4c5a9e38..c9bb937f 100644 --- a/python/src/ecole/core/observation.cpp +++ b/python/src/ecole/core/observation.cpp @@ -85,7 +85,7 @@ void bind_submodule(py::module_ const& m) { The optimization problem is represented as an heterogenous bipartite graph. On one side, a node is associated with one variable, on the other side a node is - associated with one constraint. + associated with one LP row. There exist an edge between a variable and a constraint if the variable exists in the constraint with a non-zero coefficient. @@ -236,13 +236,13 @@ void bind_submodule(py::module_ const& m) { Strong branching score observation function on branch-and bound node. This observation obtains scores for all LP or pseudo candidate variables at a - branch-and-bound node. The strong branching score measures the quality of branching - for each variable. This observation can be used as an expert for imitation - learning algorithms. + branch-and-bound node. + The strong branching score measures the quality of branching for each variable. + This observation can be used as an expert for imitation learning algorithms. This observation function extracts an array containing the strong branching score for - each variable in the problem which can be indexed by the action set. Variables for which - a strong branching score is not applicable are filled with NaN. + each variable in the problem which can be indexed by the action set. + Variables for which a strong branching score is not applicable are filled with ``NaN``. )"); strong_branching_scores.def(py::init(), py::arg("pseudo_candidates") = true, R"( Constructor for StrongBranchingScores. @@ -269,8 +269,8 @@ void bind_submodule(py::module_ const& m) { pseudocost branching (also known as hybrid branching). This observation function extracts an array containing the pseudocost for - each variable in the problem which can be indexed by the action set. Variables for which - a pseudocost is not applicable are filled with NaN. + each variable in the problem which can be indexed by the action set. + Variables for which a pseudocost is not applicable are filled with ``NaN``. )"); pseudocosts.def(py::init<>()); def_before_reset(pseudocosts, R"(Do nothing.)"); @@ -283,8 +283,8 @@ void bind_submodule(py::module_ const& m) { The observation is a matrix where rows represent all variables and columns represent features related to these variables. - Only rows representing pseudo branching candidate contain meaningful observation, other rows are filled with - ``NaN``. + Only rows representing pseudo branching candidate contain meaningful observation, other rows are filled + with ``NaN``. See [Khalil2016]_ for a complete reference on this observation function. The first :py:attr:`Khalil2016Obs.n_static_features` are static (they do not change through the solving