Skip to content

Commit

Permalink
Support batched icwt; fix icwt with scaletype='linear'; depreca…
Browse files Browse the repository at this point in the history
…tion fixes
  • Loading branch information
OverLordGoldDragon authored Jul 25, 2024
1 parent 85563d0 commit fa4c720
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 26 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
### 0.6.6

#### FEATURES
- `icwt` now supports batched `Wx` (3D `Wx`, i.e. `cwt(x)` upon 2D `x`, `(n_inputs, n_times)`)

#### FIXES
- `icwt` with `scaletype='linear'`: fix constant scaling factor
- scipy deprecation: `scipy.integrate.trapz` -> `scipy.integrate.trapezoid`
- numpy deprecation: fix `int()` upon 1D array (of size 1)
- numpy deprecation: `np.cfloat` -> `np.complex128`

### 0.6.5

#### FIXES
Expand Down
2 changes: 1 addition & 1 deletion ssqueezepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"""


__version__ = '0.6.5'
__version__ = '0.6.6-dev'
__title__ = 'ssqueezepy'
__author__ = 'John Muradeli'
__license__ = __doc__
Expand Down
25 changes: 18 additions & 7 deletions ssqueezepy/_cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,10 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True,
Wx: np.ndarray
CWT computed via `ssqueezepy.cwt`.
- 2D: (n_scales, n_times)
- 3D: (n_inputs, n_scales, n_times).
Doesn't support `one_int=False`.
wavelet: str / tuple[str, dict] / `wavelets.Wavelet`
Wavelet sampled in Fourier frequency domain.
- str: name of builtin wavelet. `ssqueezepy.wavs()`
Expand All @@ -352,6 +356,8 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True,
x_mean: float. mean of original `x` (not picked up in CWT since it's an
infinite scale component). Default 0.
Note: if `Wx` is 3D, `x_mean` should be 1D (`x.mean()` along samples
axis).
padtype: str
Pad scheme to apply on input, in case of `one_int=False`.
Expand All @@ -365,7 +371,9 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True,
# Returns:
x: np.ndarray
The signal, as reconstructed from Wx.
The signal(s), as reconstructed from Wx.
If `Wx` is 3D, `x` has shape `(n_inputs, n_times)`.
# References:
1. One integral inverse CWT. John Muradeli.
Expand Down Expand Up @@ -394,7 +402,7 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True,
synsq_cwt_iw.m
"""
#### Prepare for inversion ###############################################
na, n = Wx.shape
*_, na, n = Wx.shape
x_len = x_len or n
if not isinstance(scales, np.ndarray) and nv is None:
nv = 32 # must match forward's; default to `cwt`'s
Expand All @@ -414,27 +422,30 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True,
padtype=padtype, rpadded=rpadded, l1_norm=l1_norm)

idx = logscale_transition_idx(scales)
x = icwt(Wx[:idx], scales=scales[:idx], **kw)
x += icwt(Wx[idx:], scales=scales[idx:], **kw)
x = icwt(Wx[..., :idx, :], scales=scales[:idx], **kw)
x += icwt(Wx[..., idx:, :], scales=scales[idx:], **kw)
return x
##########################################################################

#### Invert ##############################################################
if one_int:
x = _icwt_1int(Wx, scales, scaletype, l1_norm)
else:
if Wx.ndim == 3:
raise NotImplementedError("batched `Wx` requires `one_int=True`.")
x = _icwt_2int(Wx, scales, scaletype, l1_norm,
wavelet, x_len, padtype, rpadded)

# admissibility coefficient
Cpsi = (adm_ssq(wavelet) if one_int else
adm_cwt(wavelet))
if scaletype == 'log':
# Eq 4.67 in [1]; Theorem 4.5 in [1]; below Eq 14 in [2]
# Eq 4.67 in [3]; Theorem 4.5 in [3]; below Eq 14 in [5]
# ln(2**(1/nv)) == ln(2)/nv == diff(ln(scales))[0]
x *= (2 / Cpsi) * np.log(2 ** (1 / nv))
else:
x *= (2 / Cpsi)
# unclear why the `pi/4` here but it improves inversion
x *= (2 / Cpsi) * np.pi / 4

x += x_mean # CWT doesn't capture mean (infinite scale)
return x
Expand Down Expand Up @@ -466,7 +477,7 @@ def _icwt_2int(Wx, scales, scaletype, l1_norm, wavelet, x_len,
def _icwt_1int(Wx, scales, scaletype, l1_norm):
"""One-integral iCWT; assumes analytic wavelet."""
norm = _icwt_norm(scaletype, l1_norm)
return (Wx.real / norm(scales)).sum(axis=0)
return (Wx.real / norm(scales)).sum(axis=-2)


def _icwt_norm(scaletype, l1_norm):
Expand Down
2 changes: 1 addition & 1 deletion ssqueezepy/_ssq_cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def _differentiate(Wx, dt):

# epsilon from Daubechies, H-T Wu, et al.
# gamma from Brevdo, H-T Wu, et al.
gamma = gamma or 10 * (EPS64 if Wx.dtype == np.cfloat else EPS32)
gamma = gamma or 10 * (EPS64 if Wx.dtype == np.complex128 else EPS32)
w[np.abs(Wx) < gamma] = np.inf

# see `phase_cwt`, though negatives may no longer be in minority
Expand Down
4 changes: 2 additions & 2 deletions ssqueezepy/ridge_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def fw_bw_ridge_tracking(energy_to_track, penalty_matrix, eps):

return ridge_idxs_fw_bw

eps = EPS64 if Tf.dtype == np.cfloat else EPS32
dtype = np.float64 if Tf.dtype == np.cfloat else np.float32
eps = EPS64 if Tf.dtype == np.complex128 else EPS32
dtype = np.float64 if Tf.dtype == np.complex128 else np.float32
scales, eps, penalty = [np.asarray(x, dtype=dtype)
for x in (scales, eps, penalty)]

Expand Down
6 changes: 3 additions & 3 deletions ssqueezepy/utils/cwt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def infer_scaletype(scales):
if np.mean(np.abs(np.diff(np.log(scales), 2, axis=0))) < th_log:
scaletype = 'log'
# ceil to avoid faulty float-int roundoffs
nv = int(np.round(1 / np.diff(np.log2(scales), axis=0)[0]))
nv = int(np.round(1 / np.diff(np.log2(scales), axis=0)[0].squeeze()))

elif np.mean(np.abs(np.diff(scales, 2, axis=0))) < th_lin:
scaletype = 'linear'
Expand Down Expand Up @@ -620,11 +620,11 @@ def _integrate_near_zero():
# (.001 to .1 may not be negligible, however; better captured by logspace)
t = np.logspace(-15, -1, 1000)
arr = int_fn(t)
return integrate.trapz(arr, t)
return integrate.trapezoid(arr, t)

int_nz = _integrate_near_zero()
arr, t = _find_convergent_array()
return integrate.trapz(arr, t) + int_nz
return integrate.trapezoid(arr, t) + int_nz


def find_max_scale_alt(wavelet, N, min_cutoff=.1, max_cutoff=.8):
Expand Down
2 changes: 1 addition & 1 deletion ssqueezepy/utils/fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _process_input(self, x, axis, patience, real, inverse, n):
shape_out = self._get_output_shape(x, axis, real, inverse, n)

# dtypes
double = x.dtype in (np.float64, np.cfloat)
double = x.dtype in (np.float64, np.complex128)
cdtype = 'complex128' if double else 'complex64'
rdtype = 'float64' if double else 'float32'
dtype_in = rdtype if (real and not inverse) else cdtype
Expand Down
8 changes: 4 additions & 4 deletions ssqueezepy/utils/stft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def window_resolution(window):
apsi2 = np.abs(window)**2
apsih2s = np.abs(psihs)**2

var_w = integrate.trapz(ws**2 * apsih2s, ws) / integrate.trapz(apsih2s, ws)
var_t = integrate.trapz(t**2 * apsi2, t) / integrate.trapz(apsi2, t)
var_w = integrate.trapezoid(ws**2 * apsih2s, ws) / integrate.trapezoid(apsih2s, ws)
var_t = integrate.trapezoid(t**2 * apsi2, t) / integrate.trapezoid(apsi2, t)

std_w, std_t = np.sqrt(var_w), np.sqrt(var_t)
harea = std_w * std_t
Expand All @@ -226,11 +226,11 @@ def window_area(window, time=True, frequency=False):

if time:
t = np.arange(-len(window)/2, len(window)/2, step=1)
at = integrate.trapz(np.abs(window)**2, t)
at = integrate.trapezoid(np.abs(window)**2, t)
if frequency:
ws = fftshift(_xifn(1, len(window)))
apsih2s = np.abs(fftshift(fft(window)))**2
aw = integrate.trapz(apsih2s, ws)
aw = integrate.trapezoid(apsih2s, ws)

if time and frequency:
return at, aw
Expand Down
12 changes: 6 additions & 6 deletions ssqueezepy/wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,8 @@ def _energy_wc(wavelet, scale, N, force_int):
scale = (4/pi) * wc_ct

w, psih, apsih2 = _params(wavelet, scale, N)
wc = (integrate.trapz(apsih2 * w) /
integrate.trapz(apsih2))
wc = (integrate.trapezoid(apsih2 * w) /
integrate.trapezoid(apsih2))

if use_formula:
wc *= (scale / scale_orig)
Expand Down Expand Up @@ -794,8 +794,8 @@ def _viz():
wce = center_frequency(wavelet, scale, force_int=force_int, kind='energy')

apsih2 = np.abs(psih)**2
var_w = (integrate.trapz((w - wce)**2 * apsih2, w) /
integrate.trapz(apsih2, w))
var_w = (integrate.trapezoid((w - wce)**2 * apsih2, w) /
integrate.trapezoid(apsih2, w))

std_w = np.sqrt(var_w)
if use_formula:
Expand Down Expand Up @@ -910,8 +910,8 @@ def _make_integration_t(wavelet, scale, N, min_decay, max_mult, min_mult):
psi = asnumpy(ifft(psih * (-1)**np.arange(Nt)))

apsi2 = np.abs(psi)**2
var_t = (integrate.trapz(t**2 * apsi2, t) /
integrate.trapz(apsi2, t))
var_t = (integrate.trapezoid(t**2 * apsi2, t) /
integrate.trapezoid(apsi2, t))

std_t = np.sqrt(var_t)
if use_formula:
Expand Down
2 changes: 1 addition & 1 deletion tests/z_all_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def test_trigdiff():
dWx = dWx[:, n1:n1+N]

mae = np.mean(np.abs(dWx - dWx2))
th = 1e-15 if dWx.dtype == np.cfloat else 1e-7
th = 1e-15 if dWx.dtype == np.complex128 else 1e-7
assert mae < th, mae


Expand Down

0 comments on commit fa4c720

Please sign in to comment.