
Commit d74fb1a

added AudioWarper to allow streaming in cli
1 parent 228630c

7 files changed: +316 -49 lines changed

Examples/WhisperAX/WhisperAX/Views/ContentView.swift (+1 -21)
```diff
@@ -681,26 +681,6 @@ struct ContentView: View {
         }
     }
 
-    func requestMicrophoneIfNeeded() async -> Bool {
-        let microphoneStatus = AVCaptureDevice.authorizationStatus(for: .audio)
-
-        switch microphoneStatus {
-            case .notDetermined:
-                return await withCheckedContinuation { continuation in
-                    AVCaptureDevice.requestAccess(for: .audio) { granted in
-                        continuation.resume(returning: granted)
-                    }
-                }
-            case .restricted, .denied:
-                print("Microphone access denied")
-                return false
-            case .authorized:
-                return true
-            @unknown default:
-                fatalError("Unknown authorization status")
-        }
-    }
-
     func loadModel(_ model: String, redownload: Bool = false) {
         print("Selected Model: \(UserDefaults.standard.string(forKey: "selectedModel") ?? "nil")")
 
@@ -872,7 +852,7 @@ struct ContentView: View {
     func startRecording(_ loop: Bool) {
         if let audioProcessor = whisperKit?.audioProcessor {
             Task(priority: .userInitiated) {
-                guard await requestMicrophoneIfNeeded() else {
+                guard await AudioProcessor.requestMicrophoneIfNeeded() else {
                     print("Microphone access was not granted.")
                     return
                 }
```

Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift (-20)
```diff
@@ -665,26 +665,6 @@ struct WhisperAXWatchView: View {
            }
        }
    }
-
-//    func requestMicrophoneIfNeeded() async -> Bool {
-//        let microphoneStatus = AVCaptureDevice.authorizationStatus(for: .audio)
-//
-//        switch microphoneStatus {
-//            case .notDetermined:
-//                return await withCheckedContinuation { continuation in
-//                    AVCaptureDevice.requestAccess(for: .audio) { granted in
-//                        continuation.resume(returning: granted)
-//                    }
-//                }
-//            case .restricted, .denied:
-//                print("Microphone access denied")
-//                return false
-//            case .authorized:
-//                return true
-//            @unknown default:
-//                fatalError("Unknown authorization status")
-//        }
-//    }
 }
 
 #Preview {
```

README.md (+5 -1)
````diff
@@ -124,7 +124,11 @@ You can then run them via the CLI with:
 swift run transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --audio-path "path/to/your/audio.{wav,mp3,m4a,flac}"
 ```
 
-Which should print a transcription of the audio file.
+Which should print a transcription of the audio file. If you would like to stream the audio directly from a microphone, use:
+
+```bash
+swift run transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3"
+```
 
 ## Contributing & Roadmap
 
````
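Per the commit message, this streaming mode is what the new `AudioWarper` actor (introduced below) enables: when `--audio-path` is omitted, the CLI captures microphone audio and transcribes it in real time.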
Sources/WhisperKit/Core/AudioProcessor.swift (+20)
```diff
@@ -302,6 +302,26 @@ public class AudioProcessor: NSObject, AudioProcessing {
         return convertedArray
     }
 
+    public static func requestMicrophoneIfNeeded() async -> Bool {
+        let microphoneStatus = AVCaptureDevice.authorizationStatus(for: .audio)
+
+        switch microphoneStatus {
+            case .notDetermined:
+                return await withCheckedContinuation { continuation in
+                    AVCaptureDevice.requestAccess(for: .audio) { granted in
+                        continuation.resume(returning: granted)
+                    }
+                }
+            case .restricted, .denied:
+                print("Microphone access denied")
+                return false
+            case .authorized:
+                return true
+            @unknown default:
+                fatalError("Unknown authorization status")
+        }
+    }
+
     deinit {
         stopRecording()
     }
```
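With the permission check now a static method on `AudioProcessor`, any caller can reuse it instead of rolling its own, which is exactly what the `ContentView` change above does. A minimal sketch of a hypothetical call site (the function name is illustrative, not part of this commit); note the app still needs an `NSMicrophoneUsageDescription` entry in its Info.plist for the system prompt to appear:

```swift
import WhisperKit

// Hypothetical call site: gate any capture path on the shared helper.
func beginCaptureIfAuthorized() async {
    guard await AudioProcessor.requestMicrophoneIfNeeded() else {
        // Denied, restricted, or the user dismissed the prompt.
        print("Microphone access was not granted.")
        return
    }
    // Authorized: safe to start the audio engine here.
}
```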
New file: AudioWarper (+222)
```swift
//  For licensing see accompanying LICENSE.md file.
//  Copyright © 2024 Argmax, Inc. All rights reserved.

import Foundation

extension AudioWarper {
    public struct State {
        public var isRecording: Bool = false
        public var currentFallbacks: Int = 0
        public var lastBufferSize: Int = 0
        public var lastConfirmedSegmentEndSeconds: Float = 0
        public var bufferEnergy: [Float] = []
        public var currentText: String = ""
        public var confirmedSegments: [TranscriptionSegment] = []
        public var unconfirmedSegments: [TranscriptionSegment] = []
        public var unconfirmedText: [String] = []
    }
}

/// Responsible for streaming audio from the microphone, processing it, and transcribing it in real-time.
public actor AudioWarper {
    private var state: AudioWarper.State = .init() {
        didSet {
            stateChangeCallback?(state)
        }
    }
    private let stateChangeCallback: ((AudioWarper.State) -> Void)?

    private let requiredSegmentsForConfirmation: Int
    private let useVAD: Bool
    private let silenceThreshold: Float
    private let compressionCheckWindow: Int
    private let audioProcessor: any AudioProcessing
    private let transcriber: any Transcriber
    private let decodingOptions: DecodingOptions

    public init(
        audioProcessor: any AudioProcessing,
        transcriber: any Transcriber,
        decodingOptions: DecodingOptions,
        requiredSegmentsForConfirmation: Int = 2,
        silenceThreshold: Float = 0.3,
        compressionCheckWindow: Int = 20,
        useVAD: Bool = true,
        stateChangeCallback: ((AudioWarper.State) -> Void)?
    ) {
        self.audioProcessor = audioProcessor
        self.transcriber = transcriber
        self.decodingOptions = decodingOptions
        self.requiredSegmentsForConfirmation = requiredSegmentsForConfirmation
        self.silenceThreshold = silenceThreshold
        self.compressionCheckWindow = compressionCheckWindow
        self.useVAD = useVAD
        self.stateChangeCallback = stateChangeCallback
    }

    public func startRecording() async throws {
        guard !state.isRecording else { return }
        guard await AudioProcessor.requestMicrophoneIfNeeded() else {
            print("Microphone access was not granted.")
            return
        }
        state.isRecording = true
        try audioProcessor.startRecordingLive { [weak self] _ in
            Task { [weak self] in
                await self?.onAudioBufferCallback()
            }
        }
        await realtimeLoop()
    }

    public func stopRecording() {
        state.isRecording = false
        audioProcessor.stopRecording()
    }

    private func realtimeLoop() async {
        while state.isRecording {
            do {
                try await transcribeCurrentBuffer()
            } catch {
                print("Error: \(error.localizedDescription)")
                break
            }
        }
    }

    private func onAudioBufferCallback() {
        state.bufferEnergy = audioProcessor.relativeEnergy
    }

    private func onProgressCallback(_ progress: TranscriptionProgress) {
        let fallbacks = Int(progress.timings.totalDecodingFallbacks)
        if progress.text.count < state.currentText.count {
            if fallbacks == state.currentFallbacks {
                state.unconfirmedText.append(state.currentText)
            } else {
                print("Fallback occurred: \(fallbacks)")
            }
        }
        state.currentText = progress.text
        state.currentFallbacks = fallbacks
    }

    private func transcribeCurrentBuffer() async throws {
        // Retrieve the current audio buffer from the audio processor
        let currentBuffer = audioProcessor.audioSamples

        // Calculate the size and duration of the next buffer segment
        let nextBufferSize = currentBuffer.count - state.lastBufferSize
        let nextBufferSeconds = Float(nextBufferSize) / Float(WhisperKit.sampleRate)

        // Only run the transcribe if the next buffer has at least 1 second of audio
        guard nextBufferSeconds > 1 else {
            if state.currentText == "" {
                state.currentText = "Waiting for speech..."
            }
            try await Task.sleep(nanoseconds: 100_000_000) // sleep for 100ms for next buffer
            return
        }

        if useVAD {
            // Retrieve the current relative energy values from the audio processor
            let currentRelativeEnergy = audioProcessor.relativeEnergy

            // Calculate the number of energy values to consider based on the duration of the next buffer
            // Each energy value corresponds to 1 buffer length (100ms of audio), hence we divide by 0.1
            let energyValuesToConsider = Int(nextBufferSeconds / 0.1)

            // Extract the relevant portion of energy values from the currentRelativeEnergy array
            let nextBufferEnergies = currentRelativeEnergy.suffix(energyValuesToConsider)

            // Determine the number of energy values to check for voice presence
            // Considering up to the last 1 second of audio, which translates to 10 energy values
            let numberOfValuesToCheck = max(10, nextBufferEnergies.count - 10)

            // Check if any of the energy values in the considered range exceed the silence threshold
            // This indicates the presence of voice in the buffer
            let voiceDetected = nextBufferEnergies.prefix(numberOfValuesToCheck).contains { $0 > Float(silenceThreshold) }

            // Only run the transcribe if the next buffer has voice
            guard voiceDetected else {
                if state.currentText == "" {
                    state.currentText = "Waiting for speech..."
                }
                // Sleep for 100ms and check the next buffer
                try await Task.sleep(nanoseconds: 100_000_000)
                return
            }
        }

        // Run transcribe
        state.lastBufferSize = currentBuffer.count

        let transcription = try await transcribeAudioSamples(Array(currentBuffer))

        state.currentText = ""
        state.unconfirmedText = []
        guard let segments = transcription?.segments else {
            return
        }

        // Logic for moving segments to confirmedSegments
        if segments.count > requiredSegmentsForConfirmation {
            // Calculate the number of segments to confirm
            let numberOfSegmentsToConfirm = segments.count - requiredSegmentsForConfirmation

            // Confirm the required number of segments
            let confirmedSegmentsArray = Array(segments.prefix(numberOfSegmentsToConfirm))
            let remainingSegments = Array(segments.suffix(requiredSegmentsForConfirmation))

            // Update lastConfirmedSegmentEnd based on the last confirmed segment
            if let lastConfirmedSegment = confirmedSegmentsArray.last, lastConfirmedSegment.end > state.lastConfirmedSegmentEndSeconds {
                state.lastConfirmedSegmentEndSeconds = lastConfirmedSegment.end

                // Add confirmed segments to the confirmedSegments array
                if !state.confirmedSegments.contains(confirmedSegmentsArray) {
                    state.confirmedSegments.append(contentsOf: confirmedSegmentsArray)
                }
            }

            // Update transcriptions to reflect the remaining segments
            state.unconfirmedSegments = remainingSegments
        } else {
            // Handle the case where segments are fewer or equal to required
            state.unconfirmedSegments = segments
        }
    }

    private func transcribeAudioSamples(_ samples: [Float]) async throws -> TranscriptionResult? {
        var options = decodingOptions
        options.clipTimestamps = [state.lastConfirmedSegmentEndSeconds]
        let checkWindow = compressionCheckWindow
        return try await transcriber.transcribe(audioArray: samples, decodeOptions: options) { [weak self] progress in
            Task { [weak self] in
                await self?.onProgressCallback(progress)
            }
            return Self.shouldStopEarly(progress: progress, options: options, compressionCheckWindow: checkWindow)
        }
    }

    private static func shouldStopEarly(
        progress: TranscriptionProgress,
        options: DecodingOptions,
        compressionCheckWindow: Int
    ) -> Bool? {
        let currentTokens = progress.tokens
        if currentTokens.count > compressionCheckWindow {
            let checkTokens: [Int] = currentTokens.suffix(compressionCheckWindow)
            let compressionRatio = compressionRatio(of: checkTokens)
            if compressionRatio > options.compressionRatioThreshold ?? 0.0 {
                return false
            }
        }
        if let avgLogprob = progress.avgLogprob, let logProbThreshold = options.logProbThreshold {
            if avgLogprob < logProbThreshold {
                return false
            }
        }
        return nil
    }
}
```
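A minimal sketch of how a CLI might drive this actor, assuming `WhisperKit` can be initialized from a local model folder; the initializer arguments and the printing logic below are illustrative, not part of this commit:

```swift
import WhisperKit

// Illustrative setup: load a model, then wire its pieces into AudioWarper.
let whisperKit = try await WhisperKit(modelFolder: "Models/whisperkit-coreml/openai_whisper-large-v3")

let warper = AudioWarper(
    audioProcessor: whisperKit.audioProcessor,
    transcriber: whisperKit, // WhisperKit conforms to Transcriber as of this commit
    decodingOptions: DecodingOptions(),
    stateChangeCallback: { state in
        // Confirmed segments are stable; unconfirmed ones may still be revised.
        let confirmed = state.confirmedSegments.map(\.text).joined()
        let unconfirmed = state.unconfirmedSegments.map(\.text).joined()
        print("\(confirmed) [\(unconfirmed)]")
    }
)

// Requests microphone access, then loops until stopRecording() is called.
try await warper.startRecording()
```

Because the callback fires from `state`'s `didSet`, every mutation inside the transcription loop is observable, which lets a caller print rolling hypotheses without polling the actor.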

Sources/WhisperKit/Core/WhisperKit.swift (+6 -1)
```diff
@@ -9,8 +9,13 @@ import Hub
 import TensorUtils
 import Tokenizers
 
+public protocol Transcriber {
+    func transcribe(audioPath: String, decodeOptions: DecodingOptions?, callback: TranscriptionCallback) async throws -> TranscriptionResult?
+    func transcribe(audioArray: [Float], decodeOptions: DecodingOptions?, callback: TranscriptionCallback) async throws -> TranscriptionResult?
+}
+
 @available(macOS 14, iOS 17, watchOS 10, visionOS 1, *)
-public class WhisperKit {
+public class WhisperKit: Transcriber {
     // Models
     public var modelVariant: ModelVariant = .tiny
     public var modelState: ModelState = .unloaded
```
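The protocol exists so that `AudioWarper` can depend on `any Transcriber` rather than on the concrete `WhisperKit` class. One practical consequence is testability: a stub can stand in for the full model pipeline. A hypothetical sketch, not part of this commit:

```swift
// Hypothetical test double: returns a canned result instead of running a model.
struct StubTranscriber: Transcriber {
    let canned: TranscriptionResult?

    func transcribe(audioPath: String, decodeOptions: DecodingOptions?, callback: TranscriptionCallback) async throws -> TranscriptionResult? {
        canned
    }

    func transcribe(audioArray: [Float], decodeOptions: DecodingOptions?, callback: TranscriptionCallback) async throws -> TranscriptionResult? {
        canned
    }
}
```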
