Skip to content

Commit 6373914

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Check for feasibility in gen_candidates_scipy and error out for infeasible candidates (pytorch#2737)
Summary: As titled. Previously, it was possible to return infeasible candidates to the user, with or without warnings alerting the user to the issue. This diff makes it so that the optimizer will error out when infeasible candidates are generated, so that the user can adjust the setup as needed. Resolves pytorch#2708 Also includes a couple lint fixes in optimizer tests. Differential Revision: D69314159
1 parent 8770fa4 commit 6373914

File tree

5 files changed

+66
-60
lines changed

5 files changed

+66
-60
lines changed

botorch/generation/gen.py

+23-23
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@
2020
import numpy.typing as npt
2121
import torch
2222
from botorch.acquisition import AcquisitionFunction
23-
from botorch.exceptions.errors import OptimizationGradientError
23+
from botorch.exceptions.errors import (
24+
CandidateGenerationError,
25+
OptimizationGradientError,
26+
)
2427
from botorch.exceptions.warnings import OptimizationWarning
2528
from botorch.generation.utils import _remove_fixed_features_from_optimization
2629
from botorch.logging import logger
2730
from botorch.optim.parameter_constraints import (
2831
_arrayify,
32+
evaluate_feasibility,
2933
make_scipy_bounds,
3034
make_scipy_linear_constraints,
3135
make_scipy_nonlinear_inequality_constraints,
32-
nonlinear_constraint_is_feasible,
3336
)
3437
from botorch.optim.stopping import ExpMAStoppingCriterion
3538
from botorch.optim.utils import columnwise_clamp, fix_features
@@ -237,11 +240,12 @@ def f_np_wrapper(x: npt.NDArray, f: Callable):
237240
def f(x):
238241
return -acquisition_function(x)
239242

243+
method = options.get("method", "SLSQP" if constraints else "L-BFGS-B")
240244
res = minimize_with_timeout(
241245
fun=f_np_wrapper,
242246
args=(f,),
243247
x0=x0,
244-
method=options.get("method", "SLSQP" if constraints else "L-BFGS-B"),
248+
method=method,
245249
jac=with_grad,
246250
bounds=bounds,
247251
constraints=constraints,
@@ -260,26 +264,22 @@ def f(x):
260264
fixed_features=fixed_features,
261265
)
262266

263-
# SLSQP sometimes fails in the line search or may just fail to find a feasible
264-
# candidate in which case we just return the starting point. This happens rarely,
265-
# so it shouldn't be an issue given enough restarts.
266-
if nonlinear_inequality_constraints:
267-
for con, is_intrapoint in nonlinear_inequality_constraints:
268-
if not (
269-
feasible := nonlinear_constraint_is_feasible(
270-
con, is_intrapoint=is_intrapoint, x=candidates
271-
)
272-
).all():
273-
# Replace the infeasible batches with feasible ICs.
274-
candidates[~feasible] = (
275-
torch.from_numpy(x0).to(candidates).reshape(shapeX)[~feasible]
276-
)
277-
warnings.warn(
278-
"SLSQP failed to converge to a solution the satisfies the "
279-
"non-linear constraints. Returning the feasible starting point.",
280-
OptimizationWarning,
281-
stacklevel=2,
282-
)
267+
# SLSQP can sometimes fail to produce a feasible candidate. Check for
268+
# feasibility and error out if necessary.
269+
if not (
270+
is_feasible := evaluate_feasibility(
271+
X=candidates,
272+
inequality_constraints=inequality_constraints,
273+
equality_constraints=equality_constraints,
274+
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
275+
)
276+
).all():
277+
raise CandidateGenerationError(
278+
f"The {method} optimizer produced infeasible candidates. "
279+
f"{(~is_feasible).sum().item()} out of {is_feasible.numel()} batches "
280+
"of candidates were infeasible. Please make sure the constraints are "
281+
"satisfiable and relax them if needed. "
282+
)
283283

284284
clamped_candidates = columnwise_clamp(
285285
X=candidates, lower=lower_bounds, upper=upper_bounds, raise_on_violation=True

botorch/optim/optimize.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ def optimize_acqf_mixed(
10271027

10281028
if isinstance(acq_function, OneShotAcquisitionFunction):
10291029
if not hasattr(acq_function, "evaluate") and q > 1:
1030-
raise ValueError(
1030+
raise UnsupportedError(
10311031
"`OneShotAcquisitionFunction`s that do not implement `evaluate` "
10321032
"are currently not supported when `q > 1`. This is needed to "
10331033
"compute the joint acquisition value."

botorch/optim/parameter_constraints.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -647,11 +647,11 @@ def evaluate_feasibility(
647647
intra-point or inter-point constraint (`True` for intra-point. `False` for
648648
inter-point). For more information on intra-point vs inter-point
649649
constraints, see the docstring of the `inequality_constraints` argument.
650-
tolerance: The tolerance used to check the feasibility of equality constraints
651-
and non-linear inequality constraints. For equality constraints, we check
652-
if `abs(const(X) - rhs) < tolerance`. For non-linear inequality constraints,
653-
we check if `const(X) >= -tolerance`. This avoids marking the candidates as
654-
infeasible due to tiny violations.
650+
tolerance: The tolerance used to check the feasibility of constraints.
651+
For linear inequality constraints, we check if `const(X) >= rhs - tolerance`.
652+
For equality constraints, we check if `abs(const(X) - rhs) < tolerance`.
653+
For non-linear inequality constraints, we check if `const(X) >= -tolerance`.
654+
This avoids marking the candidates as infeasible due to tiny violations.
655655
656656
Returns:
657657
A boolean tensor of shape `batch` indicating if the corresponding candidate of
@@ -662,10 +662,14 @@ def evaluate_feasibility(
662662
for idx, coef, rhs in inequality_constraints:
663663
if idx.ndim == 1:
664664
# Intra-point constraints.
665-
is_feasible &= ((X[..., idx] * coef).sum(dim=-1) >= rhs).all(dim=-1)
665+
is_feasible &= (
666+
(X[..., idx] * coef).sum(dim=-1) >= rhs - tolerance
667+
).all(dim=-1)
666668
else:
667669
# Inter-point constraints.
668-
is_feasible &= (X[..., idx[:, 0], idx[:, 1]] * coef).sum(dim=-1) >= rhs
670+
is_feasible &= (X[..., idx[:, 0], idx[:, 1]] * coef).sum(
671+
dim=-1
672+
) >= rhs - tolerance
669673
if equality_constraints is not None:
670674
for idx, coef, rhs in equality_constraints:
671675
if idx.ndim == 1:

test/generation/test_gen.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111

1212
import torch
1313
from botorch.acquisition import qExpectedImprovement, qKnowledgeGradient
14-
from botorch.exceptions.errors import OptimizationGradientError
14+
from botorch.exceptions.errors import (
15+
CandidateGenerationError,
16+
OptimizationGradientError,
17+
)
1518
from botorch.exceptions.warnings import OptimizationWarning
1619
from botorch.fit import fit_gpytorch_mll
1720
from botorch.generation.gen import (
@@ -378,6 +381,27 @@ def test_gen_candidates_scipy_invalid_method(self) -> None:
378381
upper_bounds=1,
379382
)
380383

384+
def test_gen_candidates_scipy_infeasible_candidates(self) -> None:
385+
# Check for error when infeasible candidates are generated.
386+
ics = torch.rand(2, 3, 1, device=self.device)
387+
with mock.patch(
388+
"botorch.generation.gen.minimize_with_timeout",
389+
return_value=OptimizeResult(x=ics.view(-1).cpu().numpy()),
390+
), self.assertRaisesRegex(
391+
CandidateGenerationError, "infeasible candidates. 2 out of 2"
392+
):
393+
gen_candidates_scipy(
394+
initial_conditions=ics,
395+
acquisition_function=MockAcquisitionFunction(),
396+
inequality_constraints=[
397+
( # X[..., 0] >= 2.0, which is infeasible.
398+
torch.tensor([0], device=self.device),
399+
torch.tensor([1.0], device=self.device),
400+
2.0,
401+
)
402+
],
403+
)
404+
381405

382406
class TestRandomRestartOptimization(TestBaseCandidateGeneration):
383407
def test_random_restart_optimization(self):

test/optim/test_optimize.py

+6-28
Original file line numberDiff line numberDiff line change
@@ -623,10 +623,10 @@ def test_optimize_acqf_batch_limit(self) -> None:
623623

624624
for ic_shape, expected_shape in [((2, 1, dim), 2), ((2, dim), 1)]:
625625
with self.subTest(gen_candidates=gen_candidates):
626+
ics = torch.ones((ic_shape))
626627
with self.assertWarnsRegex(
627628
RuntimeWarning, "botorch will default to old behavior"
628629
):
629-
ics = torch.ones((ic_shape))
630630
_candidates, acq_value_list = optimize_acqf(
631631
acq_function=SinOneOverXAcqusitionFunction(),
632632
bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
@@ -638,8 +638,7 @@ def test_optimize_acqf_batch_limit(self) -> None:
638638
gen_candidates=gen_candidates,
639639
batch_initial_conditions=ics,
640640
)
641-
642-
self.assertEqual(acq_value_list.shape, (expected_shape,))
641+
self.assertEqual(acq_value_list.shape, (expected_shape,))
643642

644643
def test_optimize_acqf_runs_given_batch_initial_conditions(self):
645644
num_restarts, raw_samples, dim = 1, 2, 3
@@ -915,27 +914,6 @@ def nlc1(x):
915914
torch.allclose(acq_value, torch.tensor([4], **tkwargs), atol=1e-3)
916915
)
917916

918-
# Make sure we return the initial solution if SLSQP fails to return
919-
# a feasible point.
920-
with mock.patch(
921-
"botorch.generation.gen.minimize_with_timeout"
922-
) as mock_minimize:
923-
# By setting "success" to True and "status" to 0, we prevent a
924-
# warning that `minimize` failed, which isn't the behavior
925-
# we're looking to test here.
926-
mock_minimize.return_value = OptimizeResult(
927-
x=np.array([4, 4, 4]), success=True, status=0
928-
)
929-
candidates, acq_value = optimize_acqf(
930-
acq_function=mock_acq_function,
931-
bounds=bounds,
932-
q=1,
933-
nonlinear_inequality_constraints=[(nlc1, True)],
934-
batch_initial_conditions=batch_initial_conditions,
935-
num_restarts=1,
936-
)
937-
self.assertAllClose(candidates, batch_initial_conditions[0, ...])
938-
939917
# Constrain all variables to be >= 1. The global optimum is 2.45 and
940918
# is attained by some permutation of [1, 1, 2]
941919
def nlc2(x):
@@ -1685,10 +1663,10 @@ def test_optimize_acqf_mixed_q2(self, mock_optimize_acqf):
16851663
self.assertTrue(torch.equal(acq_value, expected_acq_value))
16861664

16871665
def test_optimize_acqf_mixed_empty_ff(self):
1666+
mock_acq_function = MockAcquisitionFunction()
16881667
with self.assertRaisesRegex(
16891668
ValueError, expected_regex="fixed_features_list must be non-empty."
16901669
):
1691-
mock_acq_function = MockAcquisitionFunction()
16921670
optimize_acqf_mixed(
16931671
acq_function=mock_acq_function,
16941672
q=1,
@@ -1715,9 +1693,9 @@ def test_optimize_acqf_mixed_return_best_only_q2(self):
17151693
)
17161694

17171695
def test_optimize_acqf_one_shot_large_q(self):
1718-
with self.assertRaises(ValueError):
1719-
mock_acq_function = MockOneShotAcquisitionFunction()
1720-
fixed_features_list = [{i: i * 0.1} for i in range(2)]
1696+
mock_acq_function = MockOneShotAcquisitionFunction()
1697+
fixed_features_list = [{i: i * 0.1} for i in range(2)]
1698+
with self.assertRaisesRegex(UnsupportedError, "OneShotAcquisitionFunction"):
17211699
optimize_acqf_mixed(
17221700
acq_function=mock_acq_function,
17231701
q=2,

0 commit comments

Comments
 (0)