Skip to content

Commit

Permalink
Added implementation for SuppressBlankFilter (#18)
Browse files Browse the repository at this point in the history
* added SuppressBlankFilter

* added performance optimisation and tests
  • Loading branch information
jkrukowski authored Feb 14, 2024
1 parent 1f64dc7 commit 5f3353a
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 27 deletions.
28 changes: 11 additions & 17 deletions Sources/WhisperKit/Core/LogitsFilter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,36 @@ public protocol LogitsFiltering {
/// Logits filter that unconditionally suppresses a fixed set of tokens by
/// overwriting their logits with negative infinity.
@available(macOS 14, iOS 17, watchOS 10, visionOS 1, *)
public class SuppressTokensFilter: LogitsFiltering {
    /// Token ids whose logits are suppressed on every call.
    let suppressTokens: [Int]
    /// Multi-array indexes of the suppressed tokens, built once at init so the
    /// per-step filter does no allocation. Each index has the form `[0, 0, token]`.
    // NOTE(review): this presumes logits are laid out as [1, 1, vocabularySize] — confirm against the decoder output.
    private let suppressTokenIndexes: [[NSNumber]]

    /// - Parameter suppressTokens: Token ids to suppress.
    public init(suppressTokens: [Int]) {
        self.suppressTokens = suppressTokens
        self.suppressTokenIndexes = suppressTokens.map { [0, 0, $0 as NSNumber] }
    }

    /// Sets the logit of every suppressed token to `-FloatType.infinity`.
    ///
    /// - Parameters:
    ///   - logits: The logits to filter; modified in place.
    ///   - tokens: The tokens decoded so far (unused by this filter).
    /// - Returns: The same `logits` instance that was passed in.
    public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
        logits.fill(indexes: suppressTokenIndexes, with: -FloatType.infinity)
        return logits
    }
}

/// Logits filter that suppresses a set of "blank" tokens, but only at the
/// first sampling step, i.e. when the number of tokens decoded so far equals
/// `sampleBegin`.
@available(macOS 14, iOS 17, watchOS 10, visionOS 1, *)
public class SuppressBlankFilter: LogitsFiltering {
    /// Token ids that must not be sampled at the first sampling step.
    let suppressBlankTokens: [Int]
    /// Token count at which sampling begins; the filter is active only at this count.
    let sampleBegin: Int
    /// Multi-array indexes of the suppressed tokens, built once at init.
    /// Each index has the form `[0, 0, token]`.
    // NOTE(review): this presumes logits are laid out as [1, 1, vocabularySize] — confirm against the decoder output.
    private let suppressTokenIndexes: [[NSNumber]]

    /// - Parameters:
    ///   - suppressBlankTokens: Token ids to suppress at the first sampling step.
    ///   - sampleBegin: Number of already-present tokens at which sampling starts.
    public init(suppressBlankTokens: [Int], sampleBegin: Int) {
        self.suppressBlankTokens = suppressBlankTokens
        self.sampleBegin = sampleBegin
        self.suppressTokenIndexes = suppressBlankTokens.map { [0, 0, $0 as NSNumber] }
    }

    /// Sets the logits of the blank tokens to `-FloatType.infinity` when
    /// `tokens.count == sampleBegin`; otherwise leaves `logits` untouched.
    ///
    /// - Parameters:
    ///   - logits: The logits to filter; modified in place when the filter applies.
    ///   - tokens: The tokens decoded so far.
    /// - Returns: The same `logits` instance that was passed in.
    public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
        // Only the very first sampled position is constrained.
        guard tokens.count == sampleBegin else {
            return logits
        }
        logits.fill(indexes: suppressTokenIndexes, with: -FloatType.infinity)
        return logits
    }
}
Expand Down
2 changes: 2 additions & 0 deletions Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,7 @@ public class TextDecoderCachePrefillOutput: MLFeatureProvider {
// MARK: Tokenizer

public extension Tokenizer {
var whitespaceToken: Int { convertTokenToId(" ") ?? Self.defaultWhitespaceToken }
var specialTokenBegin: Int { convertTokenToId("<|endoftext|>") ?? Self.defaultSpecialTokenBegin }
var endToken: Int { convertTokenToId("<|endoftext|>") ?? Self.defaultEndToken }
var startOfTranscriptToken: Int { convertTokenToId("<|startoftranscript|>") ?? Self.defaultStartOfTranscriptToken }
Expand All @@ -782,6 +783,7 @@ public extension Tokenizer {
var timeTokenBegin: Int { convertTokenToId("<|0.00|>") ?? Self.defaultTimeTokenBegin }

// Default values for each token, using base vocab
// Fallback id for the single-space token " ". It must not reuse the end-of-text id (50257):
// with that value the fallback never refers to the blank token at all.
// NOTE(review): 220 is the id of " " in Whisper's GPT-2 derived vocabulary — confirm against the shipped vocab.
internal static var defaultWhitespaceToken: Int { 220 }
internal static var defaultSpecialTokenBegin: Int { 50257 }
internal static var defaultEndToken: Int { 50257 }
internal static var defaultStartOfTranscriptToken: Int { 50258 }
Expand Down
8 changes: 6 additions & 2 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,12 @@ public class TextDecoder: TextDecoding, WhisperMLModel {

var logitsFilters: [any LogitsFiltering] = []
if options.suppressBlank {
// TODO: implement
logitsFilters.append(SuppressBlankFilter(tokenizer: tokenizer, sampleBegin: prefilledIndex))
logitsFilters.append(
SuppressBlankFilter(
suppressBlankTokens: [tokenizer.whitespaceToken, tokenizer.endToken],
sampleBegin: prefilledIndex
)
)
}

if !options.supressTokens.isEmpty {
Expand Down
26 changes: 19 additions & 7 deletions Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,29 @@ import Tokenizers
// MARK: - Helpers

extension MLMultiArray {
    /// Calculate the linear offset by summing the products of each dimension’s index with the dimension’s stride.
    /// More info [here](https://developer.apple.com/documentation/coreml/mlmultiarray/2879231-subscript)
    /// - Parameters:
    ///   - index: The index of the element
    ///   - strides: The precomputed strides of the multi-array, if not provided, it will be computed. It's a performance optimization to avoid recomputing the strides every time when accessing the multi-array with multiple indexes.
    /// - Returns: The offset, in elements, of `index` from the start of the array's storage.
    @inline(__always)
    func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int {
        // Fall back to the array's own strides when the caller did not precompute them.
        let resolvedStrides = strideInts ?? strides.map { $0.intValue }
        var linearOffset = 0
        for (dimension, stride) in zip(index, resolvedStrides) {
            linearOffset += dimension.intValue * stride
        }
        return linearOffset
    }

    /// Writes `value` at each of the given multi-dimensional indexes, directly into the array's backing storage.
    ///
    /// No type or bounds checking is performed: `Value` must match the array's
    /// element type and every index must lie inside the array's shape.
    /// - Parameters:
    ///   - indexes: The indexes of the elements to overwrite.
    ///   - value: The value to store at each index.
    func fill<Value>(indexes: [[NSNumber]], with value: Value) {
        // Nothing to write; skip touching the storage and converting strides.
        guard !indexes.isEmpty else { return }
        let pointer = UnsafeMutablePointer<Value>(OpaquePointer(dataPointer))
        // Convert the strides once instead of once per index.
        let strideInts = strides.map { $0.intValue }
        for index in indexes {
            let linearOffset = linearOffset(for: index, strides: strideInts)
            pointer[linearOffset] = value
        }
    }
}

func initMLMultiArray(shape: [NSNumber], dataType: MLMultiArrayDataType, initialValue: Any) -> MLMultiArray {
Expand Down
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ public class WhisperKit {

// add them to the `allSegments` list
allSegments.append(contentsOf: currentSegments)
let allCurrentTokens = currentSegments.reduce([]) { $0 + $1.tokens }
let allCurrentTokens = currentSegments.flatMap { $0.tokens }
allTokens.append(contentsOf: allCurrentTokens)

timings.decodingWindowing += Date().timeIntervalSince(windowingStart)
Expand Down
40 changes: 40 additions & 0 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,19 @@ final class UnitTests: XCTestCase {
XCTAssertEqual(resultFull.segments.first?.end, resultSeek.segments.first?.end, "Segments should have the same end time")
}

// MARK: - Utils Tests

/// Verifies `MLMultiArray.fill(indexes:with:)` for an empty index list and for several indexes.
func testFillIndexesWithValue() throws {
    // An empty index list must leave every element unchanged.
    let untouched = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
    untouched.fill(indexes: [], with: -FloatType.infinity)
    XCTAssertEqual(untouched.data(for: 2), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])

    // Only the addressed positions (0, 1 and 5 of the last dimension) are overwritten.
    let overwritten = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
    let targets: [[NSNumber]] = [[0, 0, 0], [0, 0, 1], [0, 0, 5]]
    overwritten.fill(indexes: targets, with: -FloatType.infinity)
    XCTAssertEqual(overwritten.data(for: 2), [-.infinity, -.infinity, 0.3, 0.4, 0.5, -.infinity, 0.7])
}

// MARK: - LogitsFilter Tests

func testSuppressTokensFilter() throws {
Expand All @@ -469,6 +482,33 @@ final class UnitTests: XCTestCase {
let result3 = tokensFilter3.filterLogits(logits3, withTokens: [])
XCTAssertEqual(result3.data(for: 2), [-.infinity, 0.2, -.infinity, 0.4, 0.5, -.infinity, -.infinity])
}

/// Verifies that `SuppressBlankFilter` suppresses its tokens only when the
/// number of supplied tokens equals `sampleBegin`.
func testSuppressBlankFilter() throws {
    // No tokens to suppress: logits pass through unchanged.
    let noTokensFilter = SuppressBlankFilter(suppressBlankTokens: [], sampleBegin: 0)
    let noTokensLogits = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
    let noTokensResult = noTokensFilter.filterLogits(noTokensLogits, withTokens: [])
    XCTAssertEqual(noTokensResult.data(for: 2), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])

    // A single suppressed token at the first sampling step.
    let singleTokenFilter = SuppressBlankFilter(suppressBlankTokens: [0], sampleBegin: 0)
    let singleTokenLogits = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
    let singleTokenResult = singleTokenFilter.filterLogits(singleTokenLogits, withTokens: [])
    XCTAssertEqual(singleTokenResult.data(for: 2), [-.infinity, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])

    // Several suppressed tokens at the first sampling step.
    let multiTokenFilter = SuppressBlankFilter(suppressBlankTokens: [0, 2, 6], sampleBegin: 0)
    let multiTokenLogits = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
    let multiTokenResult = multiTokenFilter.filterLogits(multiTokenLogits, withTokens: [])
    XCTAssertEqual(multiTokenResult.data(for: 2), [-.infinity, 0.2, -.infinity, 0.4, 0.5, 0.6, -.infinity])

    // Token count equals a non-zero sampleBegin: the filter applies.
    let matchingBeginFilter = SuppressBlankFilter(suppressBlankTokens: [0, 2, 6], sampleBegin: 3)
    let matchingBeginLogits = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
    let matchingBeginResult = matchingBeginFilter.filterLogits(matchingBeginLogits, withTokens: [1, 2, 3])
    XCTAssertEqual(matchingBeginResult.data(for: 2), [-.infinity, 0.2, -.infinity, 0.4, 0.5, 0.6, -.infinity])

    // Token count differs from sampleBegin: the filter is a no-op.
    let mismatchedBeginFilter = SuppressBlankFilter(suppressBlankTokens: [0, 2, 6], sampleBegin: 5)
    let mismatchedBeginLogits = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
    let mismatchedBeginResult = mismatchedBeginFilter.filterLogits(mismatchedBeginLogits, withTokens: [1, 2, 3])
    XCTAssertEqual(mismatchedBeginResult.data(for: 2), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
}
}

// MARK: Helpers
Expand Down

0 comments on commit 5f3353a

Please sign in to comment.