Skip to content

Commit

Permalink
Add deprecated NodeBipartite column features
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoinePrv committed Sep 9, 2021
1 parent 3f37f99 commit f264f41
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
18 changes: 18 additions & 0 deletions python/src/ecole/core/observation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ void bind_submodule(py::module_ const& m) {
"row_features",
&NodeBipartiteObs::row_features,
"A matrix where each row is represents a constraint, and each column a feature of the constraints.")
// FIXME remove in version >0.8
.def_property(
"column_features",
[](py::handle self) {
PyErr_WarnEx(PyExc_DeprecationWarning, "column_features is deprecated, use variable_features.", 1);
return self.attr("variable_features");
},
[](py::handle self, py::handle const val) {
PyErr_WarnEx(PyExc_DeprecationWarning, "column_features is deprecated, use variable_features.", 1);
self.attr("variable_features") = val;
},
"A matrix where each row is represents a variable, and each column a feature of the variables.")
.def_readwrite(
"edge_features",
&NodeBipartiteObs::edge_features,
Expand Down Expand Up @@ -129,6 +141,12 @@ void bind_submodule(py::module_ const& m) {
.value("is_basis_upper", NodeBipartiteObs::VariableFeatures::is_basis_upper)
.value("is_basis_zero", NodeBipartiteObs::VariableFeatures ::is_basis_zero);

// FIXME remove in Ecole >0.8
node_bipartite_obs.def_property_readonly_static("ColumnFeatures", [](py::handle self) {
PyErr_WarnEx(PyExc_DeprecationWarning, "ColumnFeatures is deprecated, use VariableFeatures.", 1);
return self.attr("VariableFeatures");
});

py::enum_<NodeBipartiteObs::RowFeatures>(node_bipartite_obs, "RowFeatures")
.value("bias", NodeBipartiteObs::RowFeatures::bias)
.value("objective_cosine_similarity", NodeBipartiteObs::RowFeatures::objective_cosine_similarity)
Expand Down
6 changes: 6 additions & 0 deletions python/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def test_Nothing_observation(model):
assert make_obs(ecole.observation.Nothing(), model) is None


# FIXME remove in Ecole >0.8
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_NodeBipartite_observation(model):
"""Observation of NodeBipartite is a type with array attributes."""
obs = make_obs(ecole.observation.NodeBipartite(), model)
Expand All @@ -100,6 +102,10 @@ def test_NodeBipartite_observation(model):
assert len(obs.VariableFeatures.__members__) == obs.variable_features.shape[1]
assert len(obs.RowFeatures.__members__) == obs.row_features.shape[1]

# FIXME remove in Ecole >0.8
assert_array(obs.column_features, ndim=2)
assert len(obs.ColumnFeatures.__members__) == obs.variable_features.shape[1]


def test_MilpBipartite_observation(model):
"""Observation of MilpBipartite is a type with array attributes."""
Expand Down

0 comments on commit f264f41

Please sign in to comment.