Skip to content

Commit

Permalink
Fix: Upcast to np.float32 on gets, if bf16 is used
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Aug 19, 2024
1 parent ca2111a commit 25b90f1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/usearch/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def self_recall(index: Index, sample: Union[float, int] = 1.0, **kwargs) -> Sear
if "vectors" in kwargs:
vectors = kwargs.pop("vectors")
else:
vectors = index.get(keys, index.dtype)
vectors = index.get(keys)

matches = index.search(vectors, **kwargs)
count_matches: int = (
Expand Down
11 changes: 9 additions & 2 deletions python/usearch/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def _normalize_dtype(


def _to_numpy_dtype(dtype: ScalarKind):
if dtype == ScalarKind.BF16:
return None
_normalize = {
ScalarKind.F64: np.float64,
ScalarKind.F32: np.float32,
Expand Down Expand Up @@ -738,10 +740,15 @@ def get(
"""
if not dtype:
dtype = self.dtype
view_dtype = _to_numpy_dtype(dtype)
if view_dtype is None:
dtype = ScalarKind.F32
view_dtype = np.float32
else:
dtype = _normalize_dtype(dtype)

view_dtype = _to_numpy_dtype(dtype)
view_dtype = _to_numpy_dtype(dtype)
if view_dtype is None:
raise NotImplementedError("The requested representation type is not supported by NumPy")

def cast(result):
if result is not None:
Expand Down

0 comments on commit 25b90f1

Please sign in to comment.