From d20f8da548437460fe4905b54e5a62f133992ce4 Mon Sep 17 00:00:00 2001
From: Max Balandat
Date: Sat, 21 May 2022 22:24:50 -0700
Subject: [PATCH 1/5] [RFC] Validate constraints in optimize_acqf

In some cases we may allow not using box constraints (e.g. when optimizing
Alebo). Also, in general it would be good to have the optimization raise an
error if the constraint set is empty.
---
 botorch/optim/optimize.py | 102 +++++++++++++++++++++++++++++++++++---
 1 file changed, 96 insertions(+), 6 deletions(-)

diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index dd01031663..fc6f023bb4 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -12,6 +12,7 @@
 
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 from botorch.acquisition.acquisition import (
     AcquisitionFunction,
@@ -26,6 +27,7 @@
     gen_one_shot_kg_initial_conditions,
 )
 from botorch.optim.stopping import ExpMAStoppingCriterion
+from scipy.optimize import linprog
 from torch import Tensor
 
 INIT_OPTION_KEYS = {
@@ -75,10 +77,10 @@ def optimize_acqf(
         raw_samples: The number of samples for initialization. This is required
             if `batch_initial_conditions` is not specified.
         options: Options for candidate generation.
-        inequality constraints: A list of tuples (indices, coefficients, rhs),
+        inequality_constraints: A list of tuples (indices, coefficients, rhs),
            with each tuple encoding an inequality constraint of the form
            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`
-        equality constraints: A list of tuples (indices, coefficients, rhs),
+        equality_constraints: A list of tuples (indices, coefficients, rhs),
            with each tuple encoding an inequality constraint of the form
            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`
         nonlinear_inequality_constraints: A list of callables with that represent
@@ -125,10 +127,11 @@ def optimize_acqf(
         >>>     qEI, bounds, 3, 15, 256, sequential=True
         >>> )
     """
-    if not (bounds.ndim == 2 and bounds.shape[0] == 2):
-        raise ValueError(
-            f"bounds should be a `2 x d` tensor, current shape: {list(bounds.shape)}."
-        )
+    _validate_constraints(
+        bounds=bounds,
+        inequality_constraints=inequality_constraints,
+        equality_constraints=equality_constraints,
+    )
 
     if sequential and q > 1:
         if not return_best_only:
@@ -707,6 +710,93 @@ def _gen_batch_initial_conditions_local_search(
     raise RuntimeError(f"Failed to generate at least {min_points} initial conditions")
 
 
+def _validate_constraints(
+    bounds: Tensor,
+    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
+    equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
+) -> None:
+    r"""Validate constraints for acquisition function optimization.
+
+    Checks that the constraints define a bounded, non-empty polytope.
+
+    Args:
+        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
+            If there are no box constraints, bounds should be an empty `0 x d`-dim
+            tensor.
+        inequality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an inequality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`
+        equality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an equality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`
+    """
+    # We solve the following Linear Program to ensure that the constraint set
+    # is non-empty and bounded:
+    #
+    #   max_x |x|_1 s.t. 
bounds, inequality_constraints, constraints + # + # To do this we can introduce auxiliary variables s and solve the + # following standard formulation: + # + # min_(x, s) - sum_i(s_i) + # s.t. -x <= s <= x + # bounds(x) + # inequality_constraints(x) + # equality_constraints(x) + # + if bounds.numel() == 0: + if inequality_constraints is None: + raise UnsupportedError( + "Must provide either `bounds` or `inequality_constraints` (or both)." + ) + elif not (bounds.ndim == 2 and bounds.shape[0] == 2): + raise ValueError( + f"bounds should be a `2 x d` tensor, current shape: {list(bounds.shape)}." + ) + d = bounds.shape[-1] + bounds_lp, A_ub, b_ub, A_eq, b_eq = None, None, None, None, None + if bounds.numel() > 0: + bounds_lp = [tuple(b_i) for b_i in bounds.t()] + [(None, None)] * d + A_ub = np.zeros((2 * d, 2 * d)) + b_ub = np.zeros(2 * d) + A_ub[:d, :d] = -1.0 + A_ub[:d, d : 2 * d] = -1.0 + A_ub[d : 2 * d, :d] = -1.0 + A_ub[d : 2 * d, d : 2 * d] = 1.0 + if inequality_constraints is not None: + A_ineq = np.zeros((len(inequality_constraints), 2 * d)) + b_ineq = np.zeros(len(inequality_constraints)) + for i, (indices, coefficients, rhs) in enumerate(inequality_constraints): + A_ineq[i, indices] = -coefficients + b_ineq[i] = -rhs + A_ub = np.concatenate((A_ub, A_ineq)) + b_ub = np.concatenate((b_ub, b_ineq)) + if equality_constraints is not None: + A_eq = np.zeros((len(equality_constraints), 2 * d)) + b_eq = np.zeros(len(equality_constraints)) + for i, (indices, coefficients, rhs) in enumerate(equality_constraints): + A_eq[i, indices] = coefficients + b_eq[i] = rhs + c = np.concatenate((np.zeros(d), -np.ones(d))) + result = linprog( + c=c, + bounds=bounds_lp, + A_ub=A_ub, + b_ub=b_ub, + A_eq=A_eq, + b_eq=b_eq, + ) + if not result.success: + if result.status == 2: + raise ValueError("Feasible set non-empty. Check your constraints") + if result.status == 3: + raise ValueError("Feasible set unbounded.") + warnings.warn( + "Ran into issus when checking for boundedness of feasible set. " + f"Optimizer message: {result.message}." + ) + + def optimize_acqf_discrete_local_search( acq_function: AcquisitionFunction, discrete_choices: List[Tensor], From b581ec1728a7de3e33578a5b1d4967e9869a5dcd Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Sun, 22 May 2022 09:04:38 -0700 Subject: [PATCH 2/5] Add unit tests --- botorch/optim/optimize.py | 14 +++++++-- test/optim/test_optimize.py | 61 +++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py index fc6f023bb4..be4cf66e00 100644 --- a/botorch/optim/optimize.py +++ b/botorch/optim/optimize.py @@ -751,18 +751,22 @@ def _validate_constraints( ) elif not (bounds.ndim == 2 and bounds.shape[0] == 2): raise ValueError( - f"bounds should be a `2 x d` tensor, current shape: {list(bounds.shape)}." + f"bounds should be a `2 x d` tensor, current shape: {tuple(bounds.shape)}." 
) d = bounds.shape[-1] bounds_lp, A_ub, b_ub, A_eq, b_eq = None, None, None, None, None + # The first `d` variables are `x`, the last `d` are the auxiliary `s` if bounds.numel() > 0: + # `s` is unbounded bounds_lp = [tuple(b_i) for b_i in bounds.t()] + [(None, None)] * d + # Encode the constraint `-x <= s <= x` A_ub = np.zeros((2 * d, 2 * d)) b_ub = np.zeros(2 * d) A_ub[:d, :d] = -1.0 A_ub[:d, d : 2 * d] = -1.0 A_ub[d : 2 * d, :d] = -1.0 A_ub[d : 2 * d, d : 2 * d] = 1.0 + # Convet and add additional inequality constraints if present if inequality_constraints is not None: A_ineq = np.zeros((len(inequality_constraints), 2 * d)) b_ineq = np.zeros(len(inequality_constraints)) @@ -771,13 +775,16 @@ def _validate_constraints( b_ineq[i] = -rhs A_ub = np.concatenate((A_ub, A_ineq)) b_ub = np.concatenate((b_ub, b_ineq)) + # Convert equality constraints if present if equality_constraints is not None: A_eq = np.zeros((len(equality_constraints), 2 * d)) b_eq = np.zeros(len(equality_constraints)) for i, (indices, coefficients, rhs) in enumerate(equality_constraints): A_eq[i, indices] = coefficients b_eq[i] = rhs + # Objective is `- sum_i s_i` (note: the `s_i` are guaranteed to be positive) c = np.concatenate((np.zeros(d), -np.ones(d))) + # Solve the problem result = linprog( c=c, bounds=bounds_lp, @@ -786,13 +793,14 @@ def _validate_constraints( A_eq=A_eq, b_eq=b_eq, ) + # Check what's going on if unsuccessful if not result.success: if result.status == 2: - raise ValueError("Feasible set non-empty. Check your constraints") + raise ValueError("Feasible set non-empty. Check your constraints.") if result.status == 3: raise ValueError("Feasible set unbounded.") warnings.warn( - "Ran into issus when checking for boundedness of feasible set. " + "Ran into issues when checking for boundedness of feasible set. " f"Optimizer message: {result.message}." ) diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index f7a11f7856..e327a88823 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -19,6 +19,7 @@ _filter_invalid, _gen_batch_initial_conditions_local_search, _generate_neighbors, + _validate_constraints, optimize_acqf, optimize_acqf_cyclic, optimize_acqf_discrete, @@ -72,6 +73,66 @@ def rounding_func(X: Tensor) -> Tensor: class TestOptimizeAcqf(BotorchTestCase): + def test_validate_constraints(self): + for dtype in (torch.float, torch.double): + tkwargs = {"device": self.device, "dtype": dtype} + with self.assertRaisesRegex( + UnsupportedError, "Must provide either `bounds` or `inequality_constraints`" + ): + _validate_constraints(bounds=torch.empty(0, 2, **tkwargs)) + with self.assertRaisesRegex( + # TODO: Figure out why the full rendered string doesn't regex-match + ValueError, + f"bounds should be a `2 x d` tensor, current shape:", # {(3, 2)}." + ): + _validate_constraints(bounds=torch.zeros(3, 2), inequality_constraints=[]) + # Check standard box bounds + bounds = torch.stack((torch.zeros(2, **tkwargs), torch.ones(2, **tkwargs))) + _validate_constraints(bounds=bounds) + # Check failure on empty box + with self.assertRaisesRegex( + ValueError, "Feasible set non-empty. Check your constraints." 
+            ):
+                _validate_constraints(bounds=bounds.flip(0))
+            # Check failure on unbounded "box"
+            bounds[1, 1] = float("inf")
+            with self.assertRaisesRegex(ValueError, "Feasible set unbounded."):
+                _validate_constraints(bounds=bounds)
+            # Check that added inequality constraint resolves this
+            _validate_constraints(
+                bounds=bounds,
+                inequality_constraints=[
+                    (
+                        torch.tensor([1], device=self.device),
+                        torch.tensor([-1.0], **tkwargs),
+                        -2.0,
+                    )
+                ],
+            )
+            # Check hat added equality constraint resolves this
+            _validate_constraints(
+                bounds=bounds,
+                equality_constraints=[
+                    (
+                        torch.tensor([0, 1], device=self.device),
+                        torch.tensor([1.0, -1.0], **tkwargs),
+                        0.0,
+                    )
+                ],
+            )
+            # Check that inequality constraints alone work
+            zero = torch.tensor([0], device=self.device)
+            one = torch.tensor([1], device=self.device)
+            inequality_constraints = [
+                (zero, torch.tensor([1.0], **tkwargs), 0.0),
+                (zero, torch.tensor([-1.0], **tkwargs), -1.0),
+                (one, torch.tensor([1.0], **tkwargs), 0.0),
+                (one, torch.tensor([-1.0], **tkwargs), -1.0),
+            ]
+            _validate_constraints(
+                bounds=bounds, inequality_constraints=inequality_constraints
+            )
+
     @mock.patch("botorch.optim.optimize.gen_batch_initial_conditions")
     @mock.patch("botorch.optim.optimize.gen_candidates_scipy")
     def test_optimize_acqf_joint(

From 797bb5ba8d40186653be9aa32a7db96767febe48 Mon Sep 17 00:00:00 2001
From: Max Balandat
Date: Sun, 22 May 2022 09:06:46 -0700
Subject: [PATCH 3/5] Fix typo

---
 botorch/optim/optimize.py   | 5 ++++-
 test/optim/test_optimize.py | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index be4cf66e00..13726f3b93 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -733,7 +733,10 @@ def _validate_constraints(
     # We solve the following Linear Program to ensure that the constraint set
     # is non-empty and bounded:
     #
-    #   max_x |x|_1 s.t. bounds, inequality_constraints, constraints
+    #   max_x |x|_1
+    #   s.t. 
bounds(x) + # inequality_constraints(x) + # equality_constraints(x) # # To do this we can introduce auxiliary variables s and solve the # following standard formulation: diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index e327a88823..7d86a58ace 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -109,7 +109,7 @@ def test_validate_constraints(self): ) ], ) - # Check hat added equality constraint resolves this + # Check that added equality constraint resolves this _validate_constraints( bounds=bounds, equality_constraints=[ From 02ea85c3d71680d6593ddb460338b4729d2c32b9 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Sun, 22 May 2022 09:29:24 -0700 Subject: [PATCH 4/5] Fix string regex check, close coverage gap --- botorch/optim/optimize.py | 8 ++++++-- test/optim/test_optimize.py | 21 +++++++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py index 13726f3b93..66cbb7a3a8 100644 --- a/botorch/optim/optimize.py +++ b/botorch/optim/optimize.py @@ -10,6 +10,8 @@ from __future__ import annotations +import warnings + from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -19,7 +21,7 @@ OneShotAcquisitionFunction, ) from botorch.acquisition.knowledge_gradient import qKnowledgeGradient -from botorch.exceptions import InputDataError, UnsupportedError +from botorch.exceptions import InputDataError, OptimizationWarning, UnsupportedError from botorch.generation.gen import gen_candidates_scipy from botorch.logging import logger from botorch.optim.initializers import ( @@ -30,6 +32,7 @@ from scipy.optimize import linprog from torch import Tensor + INIT_OPTION_KEYS = { # set of options for initialization that we should # not pass to scipy.optimize.minimize to avoid @@ -804,7 +807,8 @@ def _validate_constraints( raise ValueError("Feasible set unbounded.") warnings.warn( "Ran into issues when checking for boundedness of feasible set. " - f"Optimizer message: {result.message}." + f"Optimizer message: {result.message}.", + OptimizationWarning, ) diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index 7d86a58ace..ed4ad5c86e 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -5,15 +5,18 @@ # LICENSE file in the root directory of this source tree. import itertools + +import warnings from unittest import mock import numpy as np import torch +from botorch import settings from botorch.acquisition.acquisition import ( AcquisitionFunction, OneShotAcquisitionFunction, ) -from botorch.exceptions import InputDataError, UnsupportedError +from botorch.exceptions import InputDataError, OptimizationWarning, UnsupportedError from botorch.optim.optimize import ( _filter_infeasible, _filter_invalid, @@ -81,9 +84,7 @@ def test_validate_constraints(self): ): _validate_constraints(bounds=torch.empty(0, 2, **tkwargs)) with self.assertRaisesRegex( - # TODO: Figure out why the full rendered string doesn't regex-match - ValueError, - f"bounds should be a `2 x d` tensor, current shape:", # {(3, 2)}." + ValueError, r"bounds should be a `2 x d` tensor, current shape: \(3, 2\)." 
): _validate_constraints(bounds=torch.zeros(3, 2), inequality_constraints=[]) # Check standard box bounds @@ -132,6 +133,18 @@ def test_validate_constraints(self): _validate_constraints( bounds=bounds, inequality_constraints=inequality_constraints ) + # Check that other messages are surfaced as warnings + bounds = torch.stack((torch.zeros(2, **tkwargs), torch.ones(2, **tkwargs))) + mock_result = OptimizeResult(success=False, status=-1, message="foo") + with mock.patch("botorch.optim.optimize.linprog", return_value=mock_result): + with warnings.catch_warnings(record=True) as ws, settings.debug(True): + _validate_constraints(bounds=bounds) + self.assertTrue(any(issubclass(w.category, OptimizationWarning)) for w in ws) + expected_msg = ( + "Ran into issues when checking for boundedness of feasible set. " + "Optimizer message: foo." + ) + self.assertTrue(any(expected_msg in str(w.message) for w in ws)) @mock.patch("botorch.optim.optimize.gen_batch_initial_conditions") @mock.patch("botorch.optim.optimize.gen_candidates_scipy") From b499da9136bf81bde98768d1ecd175385dea7f17 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Fri, 27 May 2022 19:03:09 -0700 Subject: [PATCH 5/5] Add option to validate constraints This allows not running the validation. This is useful e.g in inner loops in order to avoid re-running the validation with the same parameters repeatedly. --- botorch/optim/optimize.py | 41 ++++++++++++++++++++++++++++++++----- test/optim/test_optimize.py | 21 ++++++++++++------- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py index 66cbb7a3a8..626885e4bc 100644 --- a/botorch/optim/optimize.py +++ b/botorch/optim/optimize.py @@ -67,6 +67,7 @@ def optimize_acqf( batch_initial_conditions: Optional[Tensor] = None, return_best_only: bool = True, sequential: bool = False, + validate_constraints: bool = True, **kwargs: Any, ) -> Tuple[Tensor, Tensor]: r"""Generate a set of candidates via multi-start optimization. @@ -105,6 +106,8 @@ def optimize_acqf( random restart initializations of the optimization. sequential: If False, uses joint optimization, otherwise uses sequential optimization. + validate_constraints: If True, validate that the constraint set is + non-empty and bounded by solving a Linear Program. kwargs: Additonal keyword arguments. Returns: @@ -130,11 +133,12 @@ def optimize_acqf( >>> qEI, bounds, 3, 15, 256, sequential=True >>> ) """ - _validate_constraints( - bounds=bounds, - inequality_constraints=inequality_constraints, - equality_constraints=equality_constraints, - ) + if validate_constraints: + _validate_constraints( + bounds=bounds, + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + ) if sequential and q > 1: if not return_best_only: @@ -164,6 +168,7 @@ def optimize_acqf( batch_initial_conditions=None, return_best_only=True, sequential=False, + validate_constraints=False, ) candidate_list.append(candidate) acq_value_list.append(acq_value) @@ -273,6 +278,7 @@ def optimize_acqf_cyclic( post_processing_func: Optional[Callable[[Tensor], Tensor]] = None, batch_initial_conditions: Optional[Tensor] = None, cyclic_options: Optional[Dict[str, Union[bool, float, int, str]]] = None, + validate_constraints: bool = True, ) -> Tuple[Tensor, Tensor]: r"""Generate a set of `q` candidates via cyclic optimization. @@ -300,6 +306,8 @@ def optimize_acqf_cyclic( If no initial conditions are provided, the default initialization will be used. 
cyclic_options: Options for stopping criterion for outer cyclic optimization. + validate_constraints: If True, validate that the constraint set is + non-empty and bounded by solving a Linear Program. Returns: A two-element tuple containing @@ -334,6 +342,7 @@ def optimize_acqf_cyclic( batch_initial_conditions=batch_initial_conditions, return_best_only=True, sequential=True, + validate_constraints=validate_constraints, ) if q > 1: cyclic_options = cyclic_options or {} @@ -364,6 +373,7 @@ def optimize_acqf_cyclic( batch_initial_conditions=candidates[i].unsqueeze(0), return_best_only=True, sequential=True, + validate_constraints=False, ) candidates[i] = candidate_i acq_vals[i] = acq_val_i @@ -383,6 +393,7 @@ def optimize_acqf_list( equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, fixed_features: Optional[Dict[int, float]] = None, post_processing_func: Optional[Callable[[Tensor], Tensor]] = None, + validate_constraints: bool = True, ) -> Tuple[Tensor, Tensor]: r"""Generate a list of candidates from a list of acquisition functions. @@ -408,6 +419,8 @@ def optimize_acqf_list( post_processing_func: A function that post-processes an optimization result appropriately (i.e., according to `round-trip` transformations). + validate_constraints: If True, validate that the constraint set is + non-empty and bounded by solving a Linear Program. Returns: A two-element tuple containing @@ -419,6 +432,13 @@ def optimize_acqf_list( """ if not acq_function_list: raise ValueError("acq_function_list must be non-empty.") + if validate_constraints: + _validate_constraints( + bounds=bounds, + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + ) + candidate_list, acq_value_list = [], [] candidates = torch.tensor([], device=bounds.device, dtype=bounds.dtype) base_X_pending = acq_function_list[0].X_pending @@ -442,6 +462,7 @@ def optimize_acqf_list( post_processing_func=post_processing_func, return_best_only=True, sequential=False, + validate_constraints=False, ) candidate_list.append(candidate) acq_value_list.append(acq_value) @@ -461,6 +482,7 @@ def optimize_acqf_mixed( equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None, post_processing_func: Optional[Callable[[Tensor], Tensor]] = None, batch_initial_conditions: Optional[Tensor] = None, + validate_constraints: bool = True, **kwargs: Any, ) -> Tuple[Tensor, Tensor]: r"""Optimize over a list of fixed_features and returns the best solution. @@ -491,6 +513,8 @@ def optimize_acqf_mixed( transformations). batch_initial_conditions: A tensor to specify the initial conditions. Set this if you do not want to use default initialization strategy. + validate_constraints: If True, validate that the constraint set is + non-empty and bounded by solving a Linear Program. Returns: A two-element tuple containing @@ -508,6 +532,12 @@ def optimize_acqf_mixed( "are currently not supported when `q > 1`. This is needed to " "compute the joint acquisition value." 
) + if validate_constraints: + _validate_constraints( + bounds=bounds, + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + ) if q == 1: ff_candidate_list, ff_acq_value_list = [], [] @@ -525,6 +555,7 @@ def optimize_acqf_mixed( post_processing_func=post_processing_func, batch_initial_conditions=batch_initial_conditions, return_best_only=True, + validate_constraints=False, ) ff_candidate_list.append(candidate) ff_acq_value_list.append(acq_value) diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index ed4ad5c86e..90f0992122 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -663,11 +663,19 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf): if i == 0: # first cycle expected_call_args.update( - {"batch_initial_conditions": None, "q": q} + { + "batch_initial_conditions": None, + "q": q, + "validate_constraints": True, + } ) else: expected_call_args.update( - {"batch_initial_conditions": orig_candidates[i - 1 : i], "q": 1} + { + "batch_initial_conditions": orig_candidates[i - 1 : i], + "q": 1, + "validate_constraints": False, + } ) orig_candidates[i - 1] = candidate_rvs[i] for k, v in call_args_list[i][1].items(): @@ -689,9 +697,6 @@ def test_optimize_acqf_list(self, mock_optimize_acqf): options = {} tkwargs = {"device": self.device} bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)]) - inequality_constraints = [ - [torch.tensor([3]), torch.tensor([4]), torch.tensor(5)] - ] # reinitialize so that dtype mock_acq_function_1 = MockAcquisitionFunction() mock_acq_function_2 = MockAcquisitionFunction() @@ -701,8 +706,8 @@ def test_optimize_acqf_list(self, mock_optimize_acqf): # clear previous X_pending m.set_X_pending(None) tkwargs["dtype"] = dtype - inequality_constraints[0] = [ - t.to(**tkwargs) for t in inequality_constraints[0] + inequality_constraints = [ + [torch.tensor([3]), torch.tensor([4.0], **tkwargs), 5.0] ] mock_optimize_acqf.reset_mock() bounds = bounds.to(**tkwargs) @@ -775,6 +780,7 @@ def test_optimize_acqf_list(self, mock_optimize_acqf): "batch_initial_conditions": None, "return_best_only": True, "sequential": False, + "validate_constraints": False, } for i in range(len(call_args_list)): expected_call_args["acq_function"] = mock_acq_function_list[i] @@ -855,6 +861,7 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf): "batch_initial_conditions": None, "return_best_only": True, "sequential": False, + "validate_constraints": False, } for i in range(len(call_args_list)): expected_call_args["fixed_features"] = fixed_features_list[i]
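
For reference, the feasibility/boundedness check introduced in PATCH 1 can be reproduced standalone with NumPy/SciPy. The sketch below mirrors the construction in `_validate_constraints` (auxiliary variables `s` with `-x <= s <= x` and objective `-sum_i s_i`) for a 2-d example; the helper name `check_feasible_and_bounded` and the concrete bounds/constraint values are illustrative assumptions, not part of the patches.

import numpy as np
from scipy.optimize import linprog

def check_feasible_and_bounded(bounds, inequality_constraints):
    """Illustrative sketch of the LP built by _validate_constraints.

    bounds: list of (lb, ub) tuples, one per dimension (None = unbounded).
    inequality_constraints: list of (indices, coefficients, rhs) tuples using
        the BoTorch convention sum_i coefficients[i] * x[indices[i]] >= rhs.
    """
    d = len(bounds)
    # Variables are (x, s); the box only applies to x, s is coupled to x below.
    bounds_lp = list(bounds) + [(None, None)] * d
    A_ub = np.zeros((2 * d, 2 * d))
    b_ub = np.zeros(2 * d)
    A_ub[:d, :d] = -1.0   # -x - s <= 0, i.e. s >= -x
    A_ub[:d, d:] = -1.0
    A_ub[d:, :d] = -1.0   # -x + s <= 0, i.e. s <= x
    A_ub[d:, d:] = 1.0
    for indices, coefficients, rhs in inequality_constraints:
        row = np.zeros((1, 2 * d))
        # Flip `sum_i c_i * x_i >= rhs` into `-sum_i c_i * x_i <= -rhs`
        row[0, list(indices)] = -np.asarray(coefficients)
        A_ub = np.concatenate((A_ub, row))
        b_ub = np.concatenate((b_ub, [-rhs]))
    c = np.concatenate((np.zeros(d), -np.ones(d)))  # minimize -sum_i s_i
    result = linprog(c=c, A_ub=A_ub, b_ub=b_ub, bounds=bounds_lp)
    return result.status  # 0: solved, 2: empty set, 3: unbounded set

# Unit box with x_0 + x_1 >= 0.5: non-empty and bounded (expect status 0).
print(check_feasible_and_bounded([(0, 1), (0, 1)], [([0, 1], [1.0, 1.0], 0.5)]))
# Flipped box (lower bound above upper bound): empty (expect status 2).
print(check_feasible_and_bounded([(1, 0), (1, 0)], []))
# No upper bound on x_1 and no inequality to cap it: unbounded (expect status 3).
print(check_feasible_and_bounded([(0, 1), (0, None)], []))

The `validate_constraints` flag added in PATCH 5 is meant to skip this LP on repeated calls with identical constraints. Below is a hedged usage sketch, assuming the series is applied on top of BoTorch; the model and acquisition setup is synthetic and only there to make the snippet runnable.

import torch
from botorch.acquisition import qExpectedImprovement
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = (train_X * 2 - 1).norm(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))
acqf = qExpectedImprovement(model=model, best_f=train_Y.max())
bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(torch.double)

# The first call validates the (box-only) constraint set via the LP above ...
candidate, value = optimize_acqf(
    acq_function=acqf, bounds=bounds, q=1, num_restarts=5, raw_samples=32
)
# ... subsequent inner-loop calls with the same bounds can skip re-validation.
for _ in range(3):
    candidate, value = optimize_acqf(
        acq_function=acqf,
        bounds=bounds,
        q=1,
        num_restarts=5,
        raw_samples=32,
        validate_constraints=False,
    )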