Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add: Swift & Obj-C bindings for filteredSearch #471

Merged
merged 1 commit into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions objc/USearchObjective.mm
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,23 @@ - (UInt32)getSingle:(USearchKey)key
return static_cast<UInt32>(result);
}

- (UInt32)filteredSearchSingle:(Float32 const *_Nonnull)vector
count:(UInt32)wanted
filter:(USearchFilterFn)predicate
keys:(USearchKey *_Nullable)keys
distances:(Float32 *_Nullable)distances {
search_result_t result = _native->filtered_search(vector, static_cast<std::size_t>(wanted), predicate);

if (!result) {
@throw [NSException exceptionWithName:@"Can't find in index"
reason:[NSString stringWithUTF8String:result.error.release()]
userInfo:nil];
}

std::size_t found = result.dump_to(keys, distances);
return static_cast<UInt32>(found);
}

- (void)addDouble:(USearchKey)key
vector:(Float64 const *_Nonnull)vector {
add_result_t result = _native->add(key, (f64_t const *)vector);
Expand Down Expand Up @@ -215,6 +232,23 @@ - (UInt32)getDouble:(USearchKey)key
return static_cast<UInt32>(result);
}

- (UInt32)filteredSearchDouble:(Float64 const *_Nonnull)vector
count:(UInt32)wanted
filter:(USearchFilterFn)predicate
keys:(USearchKey *_Nullable)keys
distances:(Float32 *_Nullable)distances {
search_result_t result = _native->filtered_search((f64_t const *) vector, static_cast<std::size_t>(wanted), predicate);

if (!result) {
@throw [NSException exceptionWithName:@"Can't find in index"
reason:[NSString stringWithUTF8String:result.error.release()]
userInfo:nil];
}

std::size_t found = result.dump_to(keys, distances);
return static_cast<UInt32>(found);
}

- (void)addHalf:(USearchKey)key
vector:(void const *_Nonnull)vector {
add_result_t result = _native->add(key, (f16_t const *)vector);
Expand Down
33 changes: 33 additions & 0 deletions objc/include/USearchObjective.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ typedef NS_ENUM(NSUInteger, USearchMetric) {

typedef UInt64 USearchKey;

typedef bool (^USearchFilterFn)(USearchKey key);

API_AVAILABLE(ios(13.0), macos(10.15), tvos(13.0), watchos(6.0))
@interface USearchIndex : NSObject

Expand Down Expand Up @@ -104,6 +106,22 @@ API_AVAILABLE(ios(13.0), macos(10.15), tvos(13.0), watchos(6.0))
vector:(void *_Nonnull)vector
count:(UInt32)count NS_SWIFT_NAME(getSingle(key:vector:count:));

/**
* @brief Approximate nearest neighbors search.
* @param vector Double-precision query vector.
* @param count Upper limit on the number of matches to retrieve.
* @param filter Closure called for each key, determining whether to include or
* skip key in the results.
* @param keys Optional output buffer for keys of approximate neighbors.
* @param distances Optional output buffer for (increasing) distances to approximate neighbors.
* @return Number of matches exported to `keys` and `distances`.
*/
- (UInt32)filteredSearchSingle:(Float32 const *_Nonnull)vector
count:(UInt32)count
filter:(USearchFilterFn)filter
keys:(USearchKey *_Nullable)keys
distances:(Float32 *_Nullable)distances NS_SWIFT_NAME(filteredSearchSingle(vector:count:filter:keys:distances:));

/**
* @brief Adds a labeled vector to the index.
* @param vector Double-precision vector.
Expand Down Expand Up @@ -134,6 +152,21 @@ API_AVAILABLE(ios(13.0), macos(10.15), tvos(13.0), watchos(6.0))
vector:(void *_Nonnull)vector
count:(UInt32)count NS_SWIFT_NAME(getDouble(key:vector:count:));

/**
* @brief Approximate nearest neighbors search.
* @param vector Double-precision query vector.
* @param count Upper limit on the number of matches to retrieve.
* @param filter Closure called for each key, determining whether to include or
* skip key in the results.
* @param keys Optional output buffer for keys of approximate neighbors.
* @param distances Optional output buffer for (increasing) distances to approximate neighbors.
* @return Number of matches exported to `keys` and `distances`.
*/
- (UInt32)filteredSearchDouble:(Float64 const *_Nonnull)vector
count:(UInt32)wanted
filter:(USearchFilterFn)predicate
keys:(USearchKey *_Nullable)keys
distances:(Float32 *_Nullable)distances NS_SWIFT_NAME(filteredSearchDouble(vector:count:filter:keys:distances:));
/**
* @brief Adds a labeled vector to the index.
* @param vector Half-precision vector.
Expand Down
73 changes: 73 additions & 0 deletions swift/Index+Sugar.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ extension USearchIndex {
public typealias Key = USearchKey
public typealias Metric = USearchMetric
public typealias Scalar = USearchScalar
/// Function type used to filter out keys in results during search.
/// The filter function should return true to include, and false to skip.
public typealias FilterFn = (Key) -> Bool

/// Adds a labeled vector to the index.
/// - Parameter key: Unique identifier for that object.
Expand Down Expand Up @@ -147,6 +150,76 @@ extension USearchIndex {
}
}

/// Approximate nearest neighbors search.
/// - Parameter vector: Single-precision query vector.
/// - Parameter count: Upper limit on the number of matches to retrieve.
/// - Parameter filter: Closure used to determine whether to skip a key in the results.
/// - Returns: Labels and distances to closest approximate matches in decreasing similarity order.
/// - Throws: If runs out of memory.
public func filteredSearch(vector: ArraySlice<Float32>, count: Int, filter: @escaping FilterFn) -> ([Key], [Float])
{
var matches: [Key] = Array(repeating: 0, count: count)
var distances: [Float] = Array(repeating: 0, count: count)
let results = vector.withContiguousStorageIfAvailable {
filteredSearchSingle(
vector: $0.baseAddress!,
count:
CUnsignedInt(count),
filter: filter,
keys: &matches,
distances: &distances
)
}
matches.removeLast(count - Int(results!))
distances.removeLast(count - Int(results!))
return (matches, distances)
}

/// Approximate nearest neighbors search.
/// - Parameter vector: Single-precision query vector.
/// - Parameter count: Upper limit on the number of matches to retrieve.
/// - Parameter filter: Closure used to determine whether to skip a key in the results.
/// - Returns: Labels and distances to closest approximate matches in decreasing similarity order.
/// - Throws: If runs out of memory.
public func filteredSearch(vector: [Float32], count: Int, filter: @escaping FilterFn) -> ([Key], [Float]) {
filteredSearch(vector: vector[...], count: count, filter: filter)
}

/// Approximate nearest neighbors search.
/// - Parameter vector: Double-precision query vector.
/// - Parameter count: Upper limit on the number of matches to retrieve.
/// - Parameter filter: Closure used to determine whether to skip a key in the results.
/// - Returns: Labels and distances to closest approximate matches in decreasing similarity order.
/// - Throws: If runs out of memory.
public func filteredSearch(vector: ArraySlice<Float64>, count: Int, filter: @escaping FilterFn) -> ([Key], [Float])
{
var matches: [Key] = Array(repeating: 0, count: count)
var distances: [Float] = Array(repeating: 0, count: count)
let results = vector.withContiguousStorageIfAvailable {
filteredSearchDouble(
vector: $0.baseAddress!,
count:
CUnsignedInt(count),
filter: filter,
keys: &matches,
distances: &distances
)
}
matches.removeLast(count - Int(results!))
distances.removeLast(count - Int(results!))
return (matches, distances)
}

/// Approximate nearest neighbors search.
/// - Parameter vector: Double-precision query vector.
/// - Parameter count: Upper limit on the number of matches to retrieve.
/// - Parameter filter: Closure used to determine whether to skip a key in the results.
/// - Returns: Labels and distances to closest approximate matches in decreasing similarity order.
/// - Throws: If runs out of memory.
public func filteredSearch(vector: [Float64], count: Int, filter: @escaping FilterFn) -> ([Key], [Float]) {
filteredSearch(vector: vector[...], count: count, filter: filter)
}

#if arch(arm64)

/// Adds a labeled vector to the index.
Expand Down
118 changes: 115 additions & 3 deletions swift/Test.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,131 @@ class Test: XCTestCase {
index.add(key: 2, vector: [2.1])
index.add(key: 3, vector: [3.1])
XCTAssertEqual(index.count, 3)
XCTAssertEqual(index.search(vector: [1.0], count: 3).0, [1, 2, 3]) // works 😎
XCTAssertEqual(index.search(vector: [1.0], count: 3).0, [1, 2, 3]) // works 😎

// replace second-added entry then ensure all 3 are still returned
index.remove(key: 2)
index.add(key: 2, vector: [2.2])
XCTAssertEqual(index.count, 3)
XCTAssertEqual(index.search(vector: [1.0], count: 3).0, [1, 2, 3]) // works 😎
XCTAssertEqual(index.search(vector: [1.0], count: 3).0, [1, 2, 3]) // works 😎

// replace first-added entry then ensure all 3 are still returned
index.remove(key: 1)
index.add(key: 1, vector: [1.2])
let afterReplacingInitial = index.search(vector: [1.0], count: 3).0
XCTAssertEqual(index.count, 3)
XCTAssertEqual(afterReplacingInitial, [1, 2, 3]) // v2.11.7 fails with "[1] != [1, 2, 3]" 😨
XCTAssertEqual(afterReplacingInitial, [1, 2, 3]) // v2.11.7 fails with "[1] != [1, 2, 3]" 😨
}

func testFilteredSearchSingle() {
let index = USearchIndex.make(
metric: USearchMetric.l2sq,
dimensions: 1,
connectivity: 8,
quantization: USearchScalar.F32
)
index.reserve(3)

// add 3 entries
index.add(key: 1, vector: [1.1])
index.add(key: 2, vector: [2.1])
index.add(key: 3, vector: [3.1])
XCTAssertEqual(index.count, 3)

// filter which accepts all keys:
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 3) {
key in true
}.0,
[1, 2, 3]
) // works 😎

// filter which rejects all keys:
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 3) {
key in false
}.0,
[]
) // works 😎

// filter function accepts a set of keys passed in through a capture.
let acceptedKeys: [USearchKey] = [1, 2]
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 3) {
key in acceptedKeys.contains(key)
}.0,
acceptedKeys
) // works 😎

// filter function accepts a set of keys passed in through a capture,
// and also adheres to the count.
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 1) {
key in key > 1
}.0,
[2]
) // works 😎
XCTAssertEqual(
index.filteredSearch(vector: [1.0], count: 2) {
key in key > 1
}.0,
[2, 3]
) // works 😎
}

func testFilteredSearchDouble() {
let index = USearchIndex.make(
metric: USearchMetric.l2sq,
dimensions: 1,
connectivity: 8,
quantization: USearchScalar.F64
)
index.reserve(3)

// add 3 entries
index.add(key: 1, vector: [Float64(1.1)])
index.add(key: 2, vector: [Float64(2.1)])
index.add(key: 3, vector: [Float64(3.1)])
XCTAssertEqual(index.count, 3)

// filter which accepts all keys:
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 3) {
key in true
}.0,
[1, 2, 3]
) // works 😎

// filter which rejects all keys:
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 3) {
key in false
}.0,
[]
) // works 😎

// filter function accepts a set of keys passed in through a capture.
let acceptedKeys: [USearchKey] = [1, 2]
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 3) {
key in acceptedKeys.contains(key)
}.0,
acceptedKeys
) // works 😎

// filter function accepts a set of keys passed in through a capture,
// and also respects the count.
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 1) {
key in key > 1
}.0,
[2]
) // works 😎
XCTAssertEqual(
index.filteredSearch(vector: [Float64(1.0)], count: 2) {
key in key > 1
}.0,
[2, 3]
) // works 😎
}
}
Loading