Skip to content

Commit

Permalink
Fix bug in extra data
Browse files Browse the repository at this point in the history
Previous to this commit, when you tried to insert a single
number as extra_data, you'd encounter an exception. This fixes
that case.

Thanks to Gino Franco Fazzi at University of Copenhagen for reporting.
  • Loading branch information
brandonrobertz committed May 28, 2024
1 parent 11f3815 commit 731066a
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 3 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

setuptools.setup(
name='sparselsh',
version='2.1.1',
version='2.2.0',
author='Brandon Roberts',
author_email='[email protected]',
description='A locality sensitive hashing library with an emphasis on large, sparse datasets.',
long_description=long_description,
long_description_content_type="text/markdown",
url='https://github.com/brandonrobertz/sparselsh',
download_url='https://github.com/brandonrobertz/SparseLSH/releases/tag/v2.1.1',
download_url='https://github.com/brandonrobertz/SparseLSH/releases/tag/v2.2.0',
keywords=['clustering', 'sparse', 'lsh'],
packages=setuptools.find_packages(),
install_requires=[
Expand Down
2 changes: 1 addition & 1 deletion sparselsh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
__title__ = 'sparselsh'
__author__ = 'Brandon Roberts ([email protected])'
__license__ = 'MIT'
__version__ = '2.1.1'
__version__ = '2.2.0'

from sparselsh.lsh import LSH
2 changes: 2 additions & 0 deletions sparselsh/lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ def query(self, query_points, distance_func=None, dist_threshold=None,
# Sort extra_data by ranked distances
try:
extra_data_sorted = itemgetter(*list(indices))(extra_datas)
if isinstance(extra_data_sorted, (int, float)):
extra_data_sorted = [extra_data_sorted]
# we have no results, so no extra_datas
except TypeError:
extra_data_sorted = []
Expand Down
64 changes: 64 additions & 0 deletions tests/extra_data_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import unittest

import numpy as np
from scipy.sparse import csr_matrix

from sparselsh import LSH


class ExtraDataTestCase(unittest.TestCase):
"""
Regressions related to extra data on LSH indices.
"""
def test_accepts_same_dimensional_extra_data(self):
lsh = LSH(hash_size=256, input_dim=2)

data = csr_matrix([[0, 0], [100, 100], [200, 200]])
extra_data = [0.1, 0.2, 0.3]
lsh.index(data, extra_data=extra_data)
results = lsh.query(
csr_matrix([0, 0]), num_results=1)
(row, extra), dist = results[0]
self.assertEqual(extra, 0.1)

def test_rejects_bad_extra_data_dimensions(self):
lsh = LSH(hash_size=8, input_dim=2)
data = csr_matrix([[0, 0], [100, 100], [200, 200]])
extra_data = [0.1, 0.2, 0.3]
with self.assertRaises(AssertionError):
lsh.index(data, extra_data=extra_data[:-1])

def test_int_extra_data_regression(self):
"""
This tests the case where extra_data is a single
object, this caused an exception in versions <2.1.1
"""
lsh = LSH(hash_size=8, input_dim=2)
data = [[0, 0], [100, 100], [200, 200]]
for ix, point in enumerate(data):
x = csr_matrix(point)
lsh.index(x, extra_data=ix)

def test_string_extra_data_regression(self):
lsh = LSH(hash_size=8, input_dim=2)
data = [[0, 0], [100, 100], [200, 200]]
for ix, point in enumerate(data):
x = csr_matrix(point)
lsh.index(x, extra_data=str(ix))

def test_numpy_extra_data(self):
lsh = LSH(hash_size=8, input_dim=2)
data = [[0, 0], [100, 100], [200, 200]]
extra_datas = []
for ix, point in enumerate(data):
x = csr_matrix(point)
extra_datas.append(np.array([ix]))
lsh.index(x, extra_data=extra_datas[-1])
results = lsh.query(
csr_matrix(data[0]), num_results=1)
(row, extra), dist = results[0]
self.assertEqual(extra, extra_datas[0])


if __name__ == '__main__':
unittest.main()

0 comments on commit 731066a

Please sign in to comment.