Skip to content

Commit bc8a1ad

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Move quadrature utilities to OSS (#2780)
Summary: -- Differential Revision: D71578344
1 parent bc4b0c6 commit bc8a1ad

File tree

3 files changed

+182
-0
lines changed

3 files changed

+182
-0
lines changed

botorch/utils/quadrature.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Optional, Tuple
7+
8+
import numpy as np
9+
import torch
10+
11+
from torch import Tensor
12+
13+
14+
def clenshaw_curtis_quadrature(
15+
deg: int,
16+
a: float = 0.0,
17+
b: float = 1.0,
18+
dtype: Optional[torch.dtype] = None,
19+
device: Optional[torch.device] = None,
20+
) -> Tuple[Tensor, Tensor]:
21+
"""
22+
Clenshaw-Curtis quadrature.
23+
24+
This might be useful if we want to use Chebyshev interpolants for the evaluation
25+
of the component functions. We could even approximate the GP prior as a distribution
26+
over Chebyshev polynomials.
27+
28+
Clenshaw-Curtis quadrature uses the same nodes as Chebyshev interpolants but for
29+
integration.
30+
31+
Args:
32+
deg: Number of sample points and weights. Integrates poynomials of degree
33+
`deg - 1` exactly.
34+
a: Lower bound of the integration domain.
35+
b: Upper bound of the integration domain.
36+
dtype: Desired floating point type of the return Tensors.
37+
device: Desired device type of the return Tensors.
38+
39+
Returns:
40+
A tuple of Clenshaw-Curtis quadrature nodes and weights of length order.
41+
"""
42+
dtype = dtype if dtype is not None else torch.get_default_dtype()
43+
x, w = _clenshaw_curtis_quadrature(order=deg - 1)
44+
x = torch.as_tensor(x, dtype=dtype, device=device)
45+
w = torch.as_tensor(w, dtype=dtype, device=device)
46+
if not (a == 0 and b == 1): # need to normalize for different domain
47+
x = (b - a) * x + a
48+
w = w * (b - a)
49+
return x, w
50+
51+
52+
def higher_dimensional_quadrature(
53+
xs: Tuple[Tensor, ...], ws: Tuple[Tensor, ...]
54+
) -> Tuple[Tensor, Tensor]:
55+
"""
56+
Returns:
57+
A tuple of higher-dimensional quadrature nodes and weights. The nodes are
58+
`n^d x d`-dimensional, the weights are `n^d`-dimensional.
59+
"""
60+
x = torch.cartesian_prod(*xs)
61+
w = torch.cartesian_prod(*ws).prod(-1)
62+
return x, w
63+
64+
65+
def _clenshaw_curtis_quadrature(order: int) -> Tuple[np.ndarray, np.ndarray]:
66+
"""
67+
Clenshaw-Curtis quadrature on integration domain of [0, 1], modified from ChaosPy.
68+
69+
Args:
70+
order: Integrates poynomials of degree order.
71+
72+
Returns:
73+
A tuple of Clenshaw-Curtis quadrature nodes and weights of length order + 1.
74+
"""
75+
if order == 0:
76+
return np.array([0.5]), np.array([1.0])
77+
elif order == 1:
78+
return np.array([0.0, 1.0]), np.array([0.5, 0.5])
79+
80+
theta = (order - np.arange(order + 1)) * np.pi / order
81+
abscissas = 0.5 * np.cos(theta) + 0.5
82+
83+
steps = np.arange(1, order, 2)
84+
length = len(steps)
85+
remains = order - length
86+
87+
beta = np.hstack(
88+
[2.0 / (steps * (steps - 2)), [1.0 / steps[-1]], np.zeros(remains)]
89+
)
90+
beta = -beta[:-1] - beta[:0:-1]
91+
92+
gamma = -np.ones(order)
93+
gamma[length] += order
94+
gamma[remains] += order
95+
gamma /= order**2 - 1 + (order % 2)
96+
97+
# original implementation:
98+
weights = np.fft.ihfft(beta + gamma)
99+
if max(weights.imag) > 1e-15:
100+
raise ValueError(
101+
"Clenshaw-Curtis quadrature weights are not real. Expected imaginary "
102+
f"values to be <1e-15, got {max(weights.imag)=}"
103+
)
104+
weights = weights.real
105+
weights = np.hstack([weights, weights[len(weights) - 2 + (order % 2) :: -1]]) / 2
106+
107+
# implementation based on irfft:
108+
# weights = np.fft.irfft(beta + gamma, order)
109+
# weights = weights / 2
110+
# weights = np.hstack((weights, weights[0]))
111+
112+
return abscissas, weights

sphinx/source/utils.rst

+5
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ Safe Math
9797
.. automodule:: botorch.utils.safe_math
9898
:members:
9999

100+
Quadrature
101+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
102+
.. automodule:: botorch.utils.quadrature
103+
:members:
104+
100105
Multi-Objective Utilities
101106
-------------------------------------------
102107

test/utils/test_quadrature.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
import itertools
10+
11+
import torch
12+
from botorch.utils.quadrature import (
13+
clenshaw_curtis_quadrature,
14+
higher_dimensional_quadrature,
15+
)
16+
from botorch.utils.testing import BotorchTestCase
17+
18+
19+
class TestQuadrature(BotorchTestCase):
20+
def test_clenshaw_curtis_quadrature(self):
21+
deg_list = [5, 8, 11]
22+
bounds_list = [
23+
None,
24+
(0, 1),
25+
(-1, 1),
26+
(torch.tensor(torch.e), torch.tensor(2 * torch.pi)),
27+
]
28+
29+
for (deg, bounds), dtype in itertools.product(
30+
zip(deg_list, bounds_list), (torch.float32, torch.float64)
31+
):
32+
tkwargs = {"dtype": dtype, "device": self.device}
33+
if bounds is None:
34+
x, w = clenshaw_curtis_quadrature(deg=deg, **tkwargs)
35+
a, b = 0, 1
36+
else:
37+
a, b = bounds
38+
if isinstance(a, torch.Tensor):
39+
a = a.to(**tkwargs)
40+
if isinstance(b, torch.Tensor):
41+
b = b.to(**tkwargs)
42+
x, w = clenshaw_curtis_quadrature(deg=deg, a=a, b=b, **tkwargs)
43+
self.assertEqual(x[0].item(), a)
44+
self.assertEqual(x[-1].item(), b)
45+
self.assertEqual(len(x), deg)
46+
self.assertAllClose(w.sum(), torch.tensor(b - a, **tkwargs), atol=1e-6)
47+
48+
# integrates polynomials of degree up to deg exactly
49+
for i in range(0, deg):
50+
self.assertAllClose(
51+
x.pow(i).dot(w),
52+
torch.tensor((b ** (i + 1) - a ** (i + 1)) / (i + 1), **tkwargs),
53+
atol=1e-6,
54+
)
55+
56+
a, b = 0, 1
57+
x, w = clenshaw_curtis_quadrature(deg=deg, **tkwargs)
58+
xd, wd = higher_dimensional_quadrature((x, x), (w, w))
59+
# testing integral of multi-dimensional additive function
60+
for i in range(0, deg):
61+
self.assertAllClose(
62+
xd.pow(i).sum(dim=-1).dot(wd),
63+
2 * torch.tensor((b ** (i + 1) - a ** (i + 1)) / (i + 1), **tkwargs),
64+
atol=1e-6,
65+
)

0 commit comments

Comments
 (0)