Skip to content

Commit

Permalink
Moved many tests from notebook to test_resampler
Browse files Browse the repository at this point in the history
  • Loading branch information
hhoppe committed Feb 13, 2024
1 parent 8dc1d78 commit 598e642
Show file tree
Hide file tree
Showing 4 changed files with 2,181 additions and 3,474 deletions.
20 changes: 15 additions & 5 deletions resampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import functools
import itertools
import math
import multiprocessing
import os
import sys
import types
import typing
Expand Down Expand Up @@ -237,7 +237,17 @@ def func(array: _NDArray) -> _NDArray:
print(f'Creating numba jit-wrapper for {signature}.')
jitted_function = numba.njit(func, parallel=True, fastmath=True, cache=True)
self._jitted_function[signature] = jitted_function
result = jitted_function(a)

try:
result = jitted_function(a)
except RuntimeError:
message = (
'resampler: This runtime error may be due to a corrupt resampler/__pycache__;'
' try deleting that directory.'
)
print(message, file=sys.stderr, flush=True)
raise

return result[..., 0] if array.ndim == 2 else result


Expand Down Expand Up @@ -439,7 +449,7 @@ def premult_with_sparse(
assert self.array.ndim == sparse.ndim == 2 and sparse.shape[1] == self.array.shape[0]
# Empirically faster than with default numba.config.NUMBA_NUM_THREADS (e.g., 24).
if using_numba:
num_threads2 = min(6, multiprocessing.cpu_count()) if num_threads == 'auto' else num_threads
num_threads2 = min(6, os.cpu_count()) if num_threads == 'auto' else num_threads
src = np.ascontiguousarray(self.array) # Like .ravel() in _mul_multivector().
dtype = np.result_type(sparse.dtype, src.dtype)
dst = np.empty((sparse.shape[0], src.shape[1]), dtype)
Expand Down Expand Up @@ -3794,14 +3804,14 @@ def _find_closest_filter(filter: str, resizer: Callable[..., Any]) -> str:
skimage_transform_resize: 'box',
tf_image_resize: 'trapezoid',
torch_nn_resize: 'trapezoid',
}.get(resize, 'box')
}.get(resizer, 'box')
case 'cubic_like':
return {
cv_resize: 'sharpcubic',
scipy_ndimage_resize: 'cardinal3',
skimage_transform_resize: 'cardinal3',
torch_nn_resize: 'sharpcubic',
}.get(resize, 'cubic')
}.get(resizer, 'cubic')
case 'high_quality':
return {
pil_image_resize: 'lanczos3',
Expand Down
Loading

0 comments on commit 598e642

Please sign in to comment.