From f264f4139650ee8536c575a344ef803dbab182a8 Mon Sep 17 00:00:00 2001 From: AntoinePrv Date: Tue, 27 Jul 2021 17:09:22 -0400 Subject: [PATCH] Add deprecated NodeBipartite column features --- python/src/ecole/core/observation.cpp | 18 ++++++++++++++++++ python/tests/test_observation.py | 6 ++++++ 2 files changed, 24 insertions(+) diff --git a/python/src/ecole/core/observation.cpp b/python/src/ecole/core/observation.cpp index 4ec0af1c..4c5a9e38 100644 --- a/python/src/ecole/core/observation.cpp +++ b/python/src/ecole/core/observation.cpp @@ -102,6 +102,18 @@ void bind_submodule(py::module_ const& m) { "row_features", &NodeBipartiteObs::row_features, "A matrix where each row is represents a constraint, and each column a feature of the constraints.") + // FIXME remove in version >0.8 + .def_property( + "column_features", + [](py::handle self) { + PyErr_WarnEx(PyExc_DeprecationWarning, "column_features is deprecated, use variable_features.", 1); + return self.attr("variable_features"); + }, + [](py::handle self, py::handle const val) { + PyErr_WarnEx(PyExc_DeprecationWarning, "column_features is deprecated, use variable_features.", 1); + self.attr("variable_features") = val; + }, + "A matrix where each row is represents a variable, and each column a feature of the variables.") .def_readwrite( "edge_features", &NodeBipartiteObs::edge_features, @@ -129,6 +141,12 @@ void bind_submodule(py::module_ const& m) { .value("is_basis_upper", NodeBipartiteObs::VariableFeatures::is_basis_upper) .value("is_basis_zero", NodeBipartiteObs::VariableFeatures ::is_basis_zero); + // FIXME remove in Ecole >0.8 + node_bipartite_obs.def_property_readonly_static("ColumnFeatures", [](py::handle self) { + PyErr_WarnEx(PyExc_DeprecationWarning, "ColumnFeatures is deprecated, use VariableFeatures.", 1); + return self.attr("VariableFeatures"); + }); + py::enum_(node_bipartite_obs, "RowFeatures") .value("bias", NodeBipartiteObs::RowFeatures::bias) .value("objective_cosine_similarity", NodeBipartiteObs::RowFeatures::objective_cosine_similarity) diff --git a/python/tests/test_observation.py b/python/tests/test_observation.py index 45d26f6a..da60dcae 100644 --- a/python/tests/test_observation.py +++ b/python/tests/test_observation.py @@ -87,6 +87,8 @@ def test_Nothing_observation(model): assert make_obs(ecole.observation.Nothing(), model) is None +# FIXME remove in Ecole >0.8 +@pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_NodeBipartite_observation(model): """Observation of NodeBipartite is a type with array attributes.""" obs = make_obs(ecole.observation.NodeBipartite(), model) @@ -100,6 +102,10 @@ def test_NodeBipartite_observation(model): assert len(obs.VariableFeatures.__members__) == obs.variable_features.shape[1] assert len(obs.RowFeatures.__members__) == obs.row_features.shape[1] + # FIXME remove in Ecole >0.8 + assert_array(obs.column_features, ndim=2) + assert len(obs.ColumnFeatures.__members__) == obs.variable_features.shape[1] + def test_MilpBipartite_observation(model): """Observation of MilpBipartite is a type with array attributes."""