Skip to content

Commit

Permalink
add full-duplex-bot (#2685)
Browse files Browse the repository at this point in the history
Co-authored-by: Yulin Li <[email protected]>
  • Loading branch information
yulin-li and Yulin Li authored Dec 2, 2024
1 parent 274be90 commit 89ca39f
Show file tree
Hide file tree
Showing 19 changed files with 3,006 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ proguard-rules.pro text
*.cshtml text
*.csproj text
*.css text
*.csv text
*.editorconfig text
*.entitlements text
*.go text
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -330,5 +330,5 @@ samples/js/**/microsoft.cognitiveservices.speech.sdk.bundle*.js*
**/objectivec/**/Podfile.lock
**/objectivec/**/*.xcworkspace/

# iOS framwork
# iOS framework
samples/objective-c/ios/MicrosoftCognitiveServicesSpeech.framework/
7 changes: 7 additions & 0 deletions scenarios/full-duplex-bot/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.env
.venv
/**/venv

__pycache__/

**/tmp*.wav
20 changes: 20 additions & 0 deletions scenarios/full-duplex-bot/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
FROM mcr.microsoft.com/azurelinux/base/python:3.12

RUN tdnf distro-sync -y && \
tdnf install -y ca-certificates && \
tdnf clean all

RUN pip install poetry
RUN poetry config virtualenvs.create false

WORKDIR /app

COPY pyproject.toml poetry.lock ./

RUN poetry install --no-dev --no-root

COPY fullduplex /app

ENV FORWARDED_ALLOW_IPS="*"

ENTRYPOINT [ "uvicorn", "webapp:app", "--host", "0.0.0.0", "--port", "8080", "--proxy-headers" ]
130 changes: 130 additions & 0 deletions scenarios/full-duplex-bot/fullduplex/VAD/vad_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import copy
import torch
import numpy as np


class VADIterator:
def __init__(
self,
model,
threshold: float = 0.5,
sampling_rate: int = 16000,
min_silence_duration_ms: int = 100,
speech_pad_ms: int = 30,
):
"""
Mainly taken from https://github.com/snakers4/silero-vad
Class for stream imitation
Parameters
----------
model: preloaded .jit/.onnx silero VAD model
threshold: float (default - 0.5)
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
sampling_rate: int (default - 16000)
Currently silero VAD models support 8000 and 16000 sample rates
min_silence_duration_ms: int (default - 100 milliseconds)
In the end of each speech chunk wait for min_silence_duration_ms before separating it
speech_pad_ms: int (default - 30 milliseconds)
Final speech chunks are padded by speech_pad_ms each side
"""

self.model = model
self.threshold = threshold
self.sampling_rate = sampling_rate
self.is_speaking = False
self.buffer = []
self.start_pad_buffer = []


if sampling_rate not in [8000, 16000]:
raise ValueError(
"VADIterator does not support sampling rates other than [8000, 16000]"
)

self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000

self.reset_states()

def reset_states(self):
self.model.reset_states()
self.triggered = False
self.temp_end = 0
self.current_sample = 0

@torch.no_grad()
def __call__(self, x):
"""
x: torch.Tensor
audio chunk (see examples in repo)
return_seconds: bool (default - False)
whether return timestamps in seconds (default - samples)
"""

if not torch.is_tensor(x):
try:
x = torch.Tensor(x)
except Exception:
raise TypeError("Audio cannot be casted to tensor. Cast it manually")

window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
self.current_sample += window_size_samples

speech_prob = self.model(x, self.sampling_rate).item()

if (speech_prob >= self.threshold) and self.temp_end:
self.temp_end = 0

if (speech_prob >= self.threshold) and not self.triggered:
self.triggered = True
self.buffer = copy.deepcopy(self.start_pad_buffer)
self.buffer.append(x)
return None

if (speech_prob < self.threshold - 0.15) and self.triggered:
if not self.temp_end:
self.temp_end = self.current_sample
if self.current_sample - self.temp_end >= self.min_silence_samples:
# if self.current_sample - self.temp_end > self.speech_pad_samples:
# return None
# else:
# end of speak
self.temp_end = 0
self.triggered = False
spoken_utterance = self.buffer
self.buffer = []
return spoken_utterance

if self.triggered:
self.buffer.append(x)

self.start_pad_buffer.append(x)
self.start_pad_buffer = self.start_pad_buffer[-int(self.speech_pad_samples//window_size_samples):]

return None

def int2float(sound):
"""
Taken from https://github.com/snakers4/silero-vad
"""
sound = sound.astype("float32")
sound *= 1 / 32768
# sound = sound.squeeze() # depends on the use case
return sound

def float2int(sound):
"""
Taken from
"""

# sound = sound.squeeze() # depends on the use case
sound *= 32768
sound = np.clip(sound, -32768, 32767)
return sound.astype("int16")
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import json
import logging
import queue
import threading
from typing import Callable
from requests import Session
import wave
import torch
import numpy as np
from io import BytesIO

from VAD.vad_iterator import VADIterator, int2float, float2int

logger = logging.getLogger(__name__)


AzureADTokenProvider = Callable[[], str]

class VADHandler(threading.Thread):
def __init__(self, model, threshold, sampling_rate, min_silence_duration_ms, speech_pad_ms, stop_event, input_queue, output_queue):
super().__init__()
self.stop_event = stop_event
self.input_queue = input_queue
self.output_queue = output_queue
self.vad_iterator = VADIterator(
model,
threshold=threshold,
sampling_rate=sampling_rate,
min_silence_duration_ms=min_silence_duration_ms,
speech_pad_ms=speech_pad_ms
)


def run(self) -> None:
logger.info("VAD handler started")
while not self.stop_event.is_set():
chunk = self.input_queue.get()
if chunk is None:
break
chunk = np.frombuffer(chunk, dtype=np.int16)
vad_output = self.vad_iterator(torch.from_numpy(int2float(chunk)))
if vad_output is not None and len(vad_output) != 0:
logger.info(f"VAD output: {len(vad_output)}")
array = np.concatenate(vad_output)
self.output_queue.put(array)
self.output_queue.put(None)


class AzureFastTranscriptionClient(threading.Thread):
def __init__(self, endpoint: str, locale: str, key: str, token_provider: AzureADTokenProvider, stop_event, input_queue, callback):
super().__init__()
self.endpoint = f"{endpoint}/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
self.token_provider = token_provider
self.key = key
self.stop_event = stop_event
self.input_queue = input_queue
self.callback = callback
self.session = Session()
self.session.get(self.endpoint)
self.data = {
'definition': json.dumps({
'locales': [locale],
'profanityFilterMode': 'Masked',
'channels': [0]
})
}

def run(self) -> None:
# warm up connection
self.session.get(self.endpoint)
while not self.stop_event.is_set():
array = self.input_queue.get()
if array is None:
break
if self.token_provider is None:
headers = {
'Ocp-Apim-Subscription-Key': self.key
}
else:
headers = {
'Authorization': f'Bearer {self.token_provider()}'
}
array = float2int(array)
# open a memory buffer
tmp = BytesIO()
with wave.open(tmp, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(16000)
wf.writeframes(array.tobytes())

files = {
'audio': tmp.getbuffer()
}
response = self.session.post(self.endpoint, headers=headers, files=files, data=self.data)
if response.status_code >= 400:
print(response.text)
if self.callback is not None:
self.callback(response.json())

class AzureFastTranscriptionRecognizer:
def __init__(self, endpoint: str, token_provider: AzureADTokenProvider = None, key: str = None):
self.endpoint = endpoint
self.token_provider = token_provider
self.key = key
self.audio_queue = queue.Queue()
self.vad_queue = queue.Queue()
self.stop_event = threading.Event()
self._partial_chunk = b""
self._on_recognized = None
vad_model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
self.vad_handler = VADHandler(
model=vad_model,
threshold=0.5,
sampling_rate=16000,
min_silence_duration_ms=150,
speech_pad_ms=100,
stop_event=self.stop_event,
input_queue=self.audio_queue,
output_queue=self.vad_queue
)
self._locale = "en-US"

@property
def locale(self):
return self._locale

@locale.setter
def locale(self, value: str):
self._locale = value

def start(self):
self.recognizer_client = AzureFastTranscriptionClient(
endpoint=self.endpoint,
locale=self.locale,
key=self.key,
token_provider=self.token_provider,
stop_event=self.stop_event,
input_queue=self.vad_queue,
callback=self.on_recognized
)
threading.Thread(target=self.vad_handler.run).start()
threading.Thread(target=self.recognizer_client.run).start()

def start_continuous_recognition(self):
self.start()

def stop(self, force=False):
if force:
self.stop_event.set()
self.audio_queue.put(None)

def stop_continuous_recognition(self):
self.stop()

def __call__(self, chunk: bytes):
self._partial_chunk += chunk
while len(self._partial_chunk) >= 1024:
self.audio_queue.put(self._partial_chunk[:1024])
self._partial_chunk = self._partial_chunk[1024:]

@property
def on_recognized(self):
return self._on_recognized

@on_recognized.setter
def on_recognized(self, callback: Callable[[str], None]):
self._on_recognized = callback

if __name__ == '__main__':
import os
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
recognizer = AzureFastTranscriptionRecognizer(
endpoint=os.getenv("SPEECH_ENDPOINT"),
key=os.getenv("SPEECH_KEY"))
recognizer.on_recognized = lambda x: print(x)
recognizer.start()
with open("tests/4.wav", "rb") as f:
f.read(44)
while True:
chunk = f.read(1600)
if not chunk:
break
recognizer(chunk)
# time.sleep(0.05)
recognizer.stop()
Loading

0 comments on commit 89ca39f

Please sign in to comment.