Audio input length from CoreML metadata
Add arbitrary length audio
EduardoPach authored and ZachNagengast committed Dec 19, 2024
1 parent 2ed122e commit f63313f
Showing 6 changed files with 27 additions and 10 deletions.
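In short: the feature extractor now reports the audio window length it expects by reading its CoreML model's input description, and call sites fall back to a 30 s default (480,000 samples at 16 kHz) when that metadata is unavailable. The helper below is a minimal sketch of the idea, not part of the commit; it assumes a compiled MLModel whose audio input is a 1-D multi-array named "audio".

import CoreML

// Illustrative sketch only: derive the expected audio window length (in samples)
// from a CoreML model's input metadata. Returns nil when the model is missing,
// the "audio" input is not a multi-array, or its shape cannot be read.
func audioWindowSamples(for model: MLModel?) -> Int? {
    guard
        let input = model?.modelDescription.inputDescriptionsByName["audio"],
        input.type == .multiArray,
        let constraint = input.multiArrayConstraint
    else { return nil }
    return constraint.shape.first?.intValue
}

Callers would pair this with a default, e.g. audioWindowSamples(for: model) ?? 480_000, which is the fallback pattern the diff below adds via Constants.defaultWindowSamples.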
1 change: 0 additions & 1 deletion Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -178,7 +178,6 @@ public class AudioProcessor: NSObject, AudioProcessing {
}

public var audioBufferCallback: (([Float]) -> Void)?
- public var maxBufferLength = WhisperKit.sampleRate * WhisperKit.chunkLength // 30 seconds of audio at 16,000 Hz
public var minBufferLength = Int(Double(WhisperKit.sampleRate) * 0.1) // 0.1 second of audio at 16,000 Hz

// MARK: - Loading and conversion
10 changes: 10 additions & 0 deletions Sources/WhisperKit/Core/FeatureExtractor.swift
@@ -9,6 +9,7 @@ import Foundation

public protocol FeatureExtracting {
var melCount: Int? { get }
+ var windowSamples: Int? { get }
func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray?
}

@@ -26,6 +27,14 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
return shape[1]
}

+ public var windowSamples: Int? {
+ guard let inputDescription = model?.modelDescription.inputDescriptionsByName["audio"] else { return nil }
+ guard inputDescription.type == .multiArray else { return nil }
+ guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
+ let shape = shapeConstraint.shape.map { $0.intValue }
+ return shape[0] // The audio input is a 1D array
+ }

public func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? {
guard let model else {
throw WhisperError.modelsUnavailable()
@@ -40,4 +49,5 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
let output = MelSpectrogramOutput(features: outputFeatures)
return output.melspectrogramFeatures
}

}
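Because windowSamples is nil until a model is loaded (or when its input description cannot be read), callers pair it with the new default constant. A minimal usage sketch, assuming featureExtractor is a loaded FeatureExtractor instance:

// Usage sketch (assumes a loaded FeatureExtractor): prefer the window length
// reported by the model's metadata, fall back to the 30 s / 16 kHz default.
let windowSamples = featureExtractor.windowSamples ?? Constants.defaultWindowSamples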
2 changes: 2 additions & 0 deletions Sources/WhisperKit/Core/Models.swift
@@ -1509,6 +1509,8 @@ public enum Constants {

public static let defaultAudioReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate

+ public static let defaultWindowSamples: Int = 480_000 // 30s of audio at 16khz sample rate default for Whisper models

public static let fallbackModelSupportConfig: ModelSupportConfig = {
var config = ModelSupportConfig(
repoName: "whisperkit-coreml-fallback",
5 changes: 3 additions & 2 deletions Sources/WhisperKit/Core/TranscribeTask.swift
@@ -109,17 +109,18 @@ final class TranscribeTask {
let previousSeekProgress = progress.completedUnitCount

let windowPadding = 16000 // prevent hallucinations at the end of the clip by stopping up to 1.0s early
+ let windowSamples = featureExtractor.windowSamples ?? Constants.defaultWindowSamples
while seek < seekClipEnd - windowPadding {
// calculate new encoder segment features
let timeOffset = Float(seek) / Float(WhisperKit.sampleRate)
- let segmentSize = min(WhisperKit.windowSamples, contentFrames - seek, seekClipEnd - seek)
+ let segmentSize = min(windowSamples, contentFrames - seek, seekClipEnd - seek)
let timeOffsetEnd = Float(seek + segmentSize) / Float(WhisperKit.sampleRate)
Logging.debug("Decoding Seek: \(seek) (\(formatTimestamp(timeOffset))s)")
Logging.debug("Decoding Window Size: \(segmentSize) (\(formatTimestamp(timeOffsetEnd - timeOffset))s)")

let audioProcessingStart = Date()
let clipAudioSamples = Array(audioArray[seek..<(seek + segmentSize)])
- guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: clipAudioSamples, startAt: 0, toLength: WhisperKit.windowSamples) else {
+ guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: clipAudioSamples, startAt: 0, toLength: windowSamples) else {
throw WhisperError.transcriptionFailed("Audio samples are nil")
}
let processTime = Date().timeIntervalSince(audioProcessingStart)
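For context, the decode loop now sizes each window against the metadata-derived length instead of the removed WhisperKit.windowSamples constant. As a rough worked example: with a 480,000-sample window (30 s at 16 kHz) and a 45 s clip (720,000 samples), the first pass covers samples 0..<480,000 and the second covers 480,000..<720,000, which is then padded back to a full window before feature extraction. The helper below mirrors that arithmetic and is illustrative only, not part of the commit.

// Illustrative helper mirroring the segment sizing in the loop above.
func nextSegmentRange(seek: Int, contentFrames: Int, seekClipEnd: Int, windowSamples: Int) -> Range<Int> {
    let segmentSize = min(windowSamples, contentFrames - seek, seekClipEnd - seek)
    return seek..<(seek + segmentSize)
}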
13 changes: 8 additions & 5 deletions Sources/WhisperKit/Core/WhisperKit.swift
@@ -34,8 +34,6 @@ open class WhisperKit {
/// Shapes
public static let sampleRate: Int = 16000
public static let hopLength: Int = 160
- public static let chunkLength: Int = 30 // seconds
- public static let windowSamples: Int = 480_000 // sampleRate * chunkLength
public static let secondsPerTimeToken = Float(0.02)

/// Progress
@@ -73,6 +71,7 @@ open class WhisperKit {
modelFolder: config.modelFolder,
download: config.download
)


if let prewarm = config.prewarm, prewarm {
Logging.info("Prewarming models...")
@@ -501,7 +500,11 @@ open class WhisperKit {
decoderInputs.decoderKeyPaddingMask[0] = 0.0

// Detect language using up to the first 30 seconds
- guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: audioArray, startAt: 0, toLength: WhisperKit.windowSamples) else {
+ guard let audioSamples = AudioProcessor.padOrTrimAudio(
+ fromArray: audioArray,
+ startAt: 0,
+ toLength: featureExtractor.windowSamples ?? Constants.defaultWindowSamples
+ ) else {
throw WhisperError.transcriptionFailed("Audio samples are nil")
}
guard let melOutput = try await featureExtractor.logMelSpectrogram(fromAudio: audioSamples) else {
@@ -809,15 +812,15 @@ open class WhisperKit {
var transcribeResults = [TranscriptionResult]()

// Determine if the audio array requires chunking
- let isChunkable = audioArray.count > WhisperKit.windowSamples
+ let isChunkable = audioArray.count > featureExtractor.windowSamples ?? Constants.defaultWindowSamples
switch (isChunkable, decodeOptions?.chunkingStrategy) {
case (true, .vad):
// We have some audio that will require multiple windows and a strategy to chunk them
let vad = voiceActivityDetector ?? EnergyVAD()
let chunker = VADAudioChunker(vad: vad)
let audioChunks: [AudioChunk] = try await chunker.chunkAll(
audioArray: audioArray,
- maxChunkLength: WhisperKit.windowSamples,
+ maxChunkLength: featureExtractor.windowSamples ?? Constants.defaultWindowSamples,
decodeOptions: decodeOptions
)

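The chunking path applies the same window length: audio longer than one model window is chunkable, and with the .vad strategy each chunk is capped at that length. A hedged sketch of that call, reusing the VADAudioChunker API shown above (the wrapper function itself is hypothetical):

// Hypothetical wrapper around the VAD chunking call shown above: split the
// audio into chunks no longer than one model window before transcribing
// window-by-window.
func chunkWithVAD(_ audioArray: [Float], maxChunkLength: Int) async throws -> [AudioChunk] {
    let chunker = VADAudioChunker(vad: EnergyVAD())
    return try await chunker.chunkAll(
        audioArray: audioArray,
        maxChunkLength: maxChunkLength,
        decodeOptions: DecodingOptions()
    )
}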
6 changes: 4 additions & 2 deletions Tests/WhisperKitTests/UnitTests.swift
@@ -1420,6 +1420,8 @@ final class UnitTests: XCTestCase {

func testVADAudioChunker() async throws {
let chunker = VADAudioChunker()
+ // Setting windowSamples to default value as WhisperKit.windowSamples is not accessible in this scope
+ let windowSamples: Int = 480_000

let singleChunkPath = try XCTUnwrap(
Bundle.current.path(forResource: "jfk", ofType: "wav"),
@@ -1430,7 +1432,7 @@

var audioChunks = try await chunker.chunkAll(
audioArray: audioArray,
- maxChunkLength: WhisperKit.windowSamples,
+ maxChunkLength: windowSamples,
decodeOptions: DecodingOptions()
)

@@ -1445,7 +1447,7 @@

audioChunks = try await chunker.chunkAll(
audioArray: audioArray,
- maxChunkLength: WhisperKit.windowSamples,
+ maxChunkLength: windowSamples,
decodeOptions: DecodingOptions()
)

