Skip to content

Commit

Permalink
Add: Swift & Obj-C bindings for filteredSearch
Browse files Browse the repository at this point in the history
Added a couple of tests in Swift to test both functions.

Issue: #470
  • Loading branch information
vardhan committed Aug 25, 2024
1 parent a2719b9 commit e8bf04c
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 14 deletions.
34 changes: 34 additions & 0 deletions objc/USearchObjective.mm
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,23 @@ - (UInt32)getSingle:(USearchKey)key
return static_cast<UInt32>(result);
}

- (UInt32)filteredSearchSingle:(Float32 const *_Nonnull)vector
count:(UInt32)wanted
filter:(USearchFilterFn)predicate
keys:(USearchKey *_Nullable)keys
distances:(Float32 *_Nullable)distances {
search_result_t result = _native->filtered_search(vector, static_cast<std::size_t>(wanted), predicate);

if (!result) {
@throw [NSException exceptionWithName:@"Can't find in index"
reason:[NSString stringWithUTF8String:result.error.release()]
userInfo:nil];
}

std::size_t found = result.dump_to(keys, distances);
return static_cast<UInt32>(found);
}

- (void)addDouble:(USearchKey)key
vector:(Float64 const *_Nonnull)vector {
add_result_t result = _native->add(key, (f64_t const *)vector);
Expand Down Expand Up @@ -215,6 +232,23 @@ - (UInt32)getDouble:(USearchKey)key
return static_cast<UInt32>(result);
}

- (UInt32)filteredSearchDouble:(Float64 const *_Nonnull)vector
count:(UInt32)wanted
filter:(USearchFilterFn)predicate
keys:(USearchKey *_Nullable)keys
distances:(Float32 *_Nullable)distances {
search_result_t result = _native->filtered_search((f64_t const *) vector, static_cast<std::size_t>(wanted), predicate);

if (!result) {
@throw [NSException exceptionWithName:@"Can't find in index"
reason:[NSString stringWithUTF8String:result.error.release()]
userInfo:nil];
}

std::size_t found = result.dump_to(keys, distances);
return static_cast<UInt32>(found);
}

- (void)addHalf:(USearchKey)key
vector:(void const *_Nonnull)vector {
add_result_t result = _native->add(key, (f16_t const *)vector);
Expand Down
28 changes: 18 additions & 10 deletions objc/include/USearchObjective.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ typedef NS_ENUM(NSUInteger, USearchMetric) {

typedef UInt64 USearchKey;

typedef bool (^USearchFilterFn)(USearchKey key);

API_AVAILABLE(ios(13.0), macos(10.15), tvos(13.0), watchos(6.0))
@interface USearchIndex : NSObject

Expand Down Expand Up @@ -104,6 +106,22 @@ API_AVAILABLE(ios(13.0), macos(10.15), tvos(13.0), watchos(6.0))
vector:(void *_Nonnull)vector
count:(UInt32)count NS_SWIFT_NAME(getSingle(key:vector:count:));

/**
* @brief Approximate nearest neighbors search.
* @param vector Double-precision query vector.
* @param count Upper limit on the number of matches to retrieve.
* @param filter Closure called for each key, determining whether to include or
* skip key in the results.
* @param keys Optional output buffer for keys of approximate neighbors.
* @param distances Optional output buffer for (increasing) distances to approximate neighbors.
* @return Number of matches exported to `keys` and `distances`.
*/
- (UInt32)filteredSearchSingle:(Float32 const *_Nonnull)vector
count:(UInt32)count
filter:(USearchFilterFn)filter
keys:(USearchKey *_Nullable)keys
distances:(Float32 *_Nullable)distances NS_SWIFT_NAME(filteredSearchSingle(vector:count:filter:keys:distances:));

/**
* @brief Adds a labeled vector to the index.
* @param vector Double-precision vector.
Expand All @@ -124,16 +142,6 @@ API_AVAILABLE(ios(13.0), macos(10.15), tvos(13.0), watchos(6.0))
keys:(USearchKey *_Nullable)keys
distances:(Float32 *_Nullable)distances NS_SWIFT_NAME(searchDouble(vector:count:keys:distances:));

/**
* @brief Retrieves a labeled double-precision vector from the index.
* @param vector A buffer to store the vector.
* @param count For multi-indexes, the number of vectors to retrieve.
* @return Number of vectors exported to `vector`.
*/
- (UInt32)getDouble:(USearchKey)key
vector:(void *_Nonnull)vector
count:(UInt32)count NS_SWIFT_NAME(getDouble(key:vector:count:));

/**
* @brief Adds a labeled vector to the index.
* @param vector Half-precision vector.
Expand Down
2 changes: 1 addition & 1 deletion simsimd
73 changes: 73 additions & 0 deletions swift/Index+Sugar.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ extension USearchIndex {
public typealias Key = USearchKey
public typealias Metric = USearchMetric
public typealias Scalar = USearchScalar
/// Function type used to filter out keys in results during search.
/// The filter function should return true to include, and false to skip.
public typealias FilterFn = (Key) -> Bool

/// Adds a labeled vector to the index.
/// - Parameter key: Unique identifier for that object.
Expand Down Expand Up @@ -147,6 +150,76 @@ extension USearchIndex {
}
}

/// Approximate nearest neighbors search.
/// - Parameter vector: Single-precision query vector.
/// - Parameter count: Upper limit on the number of matches to retrieve.
/// - Parameter filter: Closure used to determine whether to skip a key in the results.
/// - Returns: Labels and distances to closest approximate matches in decreasing similarity order.
/// - Throws: If runs out of memory.
public func filteredSearch(vector: ArraySlice<Float32>, count: Int, filter: @escaping FilterFn) -> ([Key], [Float])
{
var matches: [Key] = Array(repeating: 0, count: count)
var distances: [Float] = Array(repeating: 0, count: count)
let results = vector.withContiguousStorageIfAvailable {
filteredSearchSingle(
vector: $0.baseAddress!,
count:
CUnsignedInt(count),
filter: filter,
keys: &matches,
distances: &distances
)
}
matches.removeLast(count - Int(results!))
distances.removeLast(count - Int(results!))
return (matches, distances)
}

/// Approximate nearest neighbors search.
/// - Parameter vector: Single-precision query vector.
/// - Parameter count: Upper limit on the number of matches to retrieve.
/// - Parameter filter: Closure used to determine whether to skip a key in the results.
/// - Returns: Labels and distances to closest approximate matches in decreasing similarity order.
/// - Throws: If runs out of memory.
public func filteredSearch(vector: [Float32], count: Int, filter: @escaping FilterFn) -> ([Key], [Float]) {
filteredSearch(vector: vector[...], count: count, filter: filter)
}

/// Approximate nearest neighbors search.
/// - Parameter vector: Double-precision query vector.
/// - Parameter count: Upper limit on the number of matches to retrieve.
/// - Parameter filter: Closure used to determine whether to skip a key in the results.
/// - Returns: Labels and distances to closest approximate matches in decreasing similarity order.
/// - Throws: If runs out of memory.
public func filteredSearch(vector: ArraySlice<Float64>, count: Int, filter: @escaping FilterFn) -> ([Key], [Float])
{
var matches: [Key] = Array(repeating: 0, count: count)
var distances: [Float] = Array(repeating: 0, count: count)
let results = vector.withContiguousStorageIfAvailable {
filteredSearchDouble(
vector: $0.baseAddress!,
count:
CUnsignedInt(count),
filter: filter,
keys: &matches,
distances: &distances
)
}
matches.removeLast(count - Int(results!))
distances.removeLast(count - Int(results!))
return (matches, distances)
}

/// Approximate nearest neighbors search.
/// - Parameter vector: Double-precision query vector.
/// - Parameter count: Upper limit on the number of matches to retrieve.
/// - Parameter filter: Closure used to determine whether to skip a key in the results.
/// - Returns: Labels and distances to closest approximate matches in decreasing similarity order.
/// - Throws: If runs out of memory.
public func filteredSearch(vector: [Float64], count: Int, filter: @escaping FilterFn) -> ([Key], [Float]) {
filteredSearch(vector: vector[...], count: count, filter: filter)
}

#if arch(arm64)

/// Adds a labeled vector to the index.
Expand Down
118 changes: 115 additions & 3 deletions swift/Test.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,131 @@ class Test: XCTestCase {
index.add(key: 2, vector: [2.1])
index.add(key: 3, vector: [3.1])
XCTAssertEqual(index.count, 3)
XCTAssertEqual(index.search(vector: [1.0], count: 3).0, [1, 2, 3]) // works 😎
XCTAssertEqual(index.search(vector: [1.0], count: 3).0, [1, 2, 3]) // works 😎

// replace second-added entry then ensure all 3 are still returned
index.remove(key: 2)
index.add(key: 2, vector: [2.2])
XCTAssertEqual(index.count, 3)
XCTAssertEqual(index.search(vector: [1.0], count: 3).0, [1, 2, 3]) // works 😎
XCTAssertEqual(index.search(vector: [1.0], count: 3).0, [1, 2, 3]) // works 😎

// replace first-added entry then ensure all 3 are still returned
index.remove(key: 1)
index.add(key: 1, vector: [1.2])
let afterReplacingInitial = index.search(vector: [1.0], count: 3).0
XCTAssertEqual(index.count, 3)
XCTAssertEqual(afterReplacingInitial, [1, 2, 3]) // v2.11.7 fails with "[1] != [1, 2, 3]" 😨
XCTAssertEqual(afterReplacingInitial, [1, 2, 3]) // v2.11.7 fails with "[1] != [1, 2, 3]" 😨
}

func testFilteredSearchSingle() {
let index = USearchIndex.make(
metric: USearchMetric.l2sq,
dimensions: 1,
connectivity: 8,
quantization: USearchScalar.F32
)
index.reserve(3)

// add 3 entries
index.add(key: 1, vector: [1.1])
index.add(key: 2, vector: [2.1])
index.add(key: 3, vector: [3.1])
XCTAssertEqual(index.count, 3)

// filter which accepts all keys:
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 3) {
key in true
}.0,
[1, 2, 3]
) // works 😎

// filter which rejects all keys:
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 3) {
key in false
}.0,
[]
) // works 😎

// filter function accepts a set of keys passed in through a capture.
let acceptedKeys: [USearchKey] = [1, 2]
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 3) {
key in acceptedKeys.contains(key)
}.0,
acceptedKeys
) // works 😎

// filter function accepts a set of keys passed in through a capture,
// and also adheres to the count.
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 1) {
key in key > 1
}.0,
[2]
) // works 😎
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 2) {
key in key > 1
}.0,
[2, 3]
) // works 😎
}

func testFilteredSearchDouble() {
let index = USearchIndex.make(
metric: USearchMetric.l2sq,
dimensions: 1,
connectivity: 8,
quantization: USearchScalar.F64
)
index.reserve(3)

// add 3 entries
index.add(key: 1, vector: [Float64(1.1)])
index.add(key: 2, vector: [Float64(2.1)])
index.add(key: 3, vector: [Float64(3.1)])
XCTAssertEqual(index.count, 3)

// filter which accepts all keys:
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 3) {
key in true
}.0,
[1, 2, 3]
) // works 😎

// filter which rejects all keys:
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 3) {
key in false
}.0,
[]
) // works 😎

// filter function accepts a set of keys passed in through a capture.
let acceptedKeys: [USearchKey] = [1, 2]
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 3) {
key in acceptedKeys.contains(key)
}.0,
acceptedKeys
) // works 😎

// filter function accepts a set of keys passed in through a capture,
// and also respects the count.
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 1) {
key in key > 1
}.0,
[2]
) // works 😎
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 2) {
key in key > 1
}.0,
[2, 3]
) // works 😎
}
}

0 comments on commit e8bf04c

Please sign in to comment.