Skip to content

Commit

Permalink
add s2s and s2t client utility functions (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmittal-github authored Apr 26, 2023
1 parent 2d3a719 commit 22438ab
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 2 deletions.
1 change: 1 addition & 0 deletions riva/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@
from riva.client.proto.riva_asr_pb2 import RecognitionConfig, StreamingRecognitionConfig
from riva.client.proto.riva_audio_pb2 import AudioEncoding
from riva.client.proto.riva_nlp_pb2 import AnalyzeIntentOptions
from riva.client.proto.riva_nmt_pb2 import StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig, StreamingTranslateSpeechToTextConfig
from riva.client.tts import SpeechSynthesisService
from riva.client.nmt import NeuralMachineTranslationClient
109 changes: 107 additions & 2 deletions riva/client/nmt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

from typing import Generator, Optional, Union, List

from typing import Callable, Dict, Generator, Iterable, List, Optional, TextIO, Union
from grpc._channel import _MultiThreadedRendezvous

import riva.client.proto.riva_nmt_pb2 as riva_nmt
import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv
from riva.client import Auth

def streaming_s2s_request_generator(
    audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechRequest, None, None]:
    """Yield S2S streaming requests: first the config, then one request per audio chunk.

    The Riva streaming protocol expects the very first request of the stream to
    carry only the configuration; every subsequent request carries raw audio.
    """
    request_cls = riva_nmt.StreamingTranslateSpeechToSpeechRequest
    # Configuration must be sent before any audio content.
    yield request_cls(config=streaming_config)
    for fragment in audio_chunks:
        yield request_cls(audio_content=fragment)

def streaming_s2t_request_generator(
    audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToTextRequest, None, None]:
    """Yield S2T streaming requests: first the config, then one request per audio chunk.

    The Riva streaming protocol expects the very first request of the stream to
    carry only the configuration; every subsequent request carries raw audio.
    """
    request_cls = riva_nmt.StreamingTranslateSpeechToTextRequest
    # Configuration must be sent before any audio content.
    yield request_cls(config=streaming_config)
    for fragment in audio_chunks:
        yield request_cls(audio_content=fragment)

class NeuralMachineTranslationClient:
"""
Expand All @@ -25,6 +37,99 @@ def __init__(self, auth: Auth) -> None:
self.auth = auth
self.stub = riva_nmt_srv.RivaTranslationStub(self.auth.channel)

def streaming_s2s_response_generator(
    self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechResponse, None, None]:
    """Perform "online" speech-to-speech translation over a gRPC stream.

    Sends the audio fragments in :param:`audio_chunks` to the server and
    yields translation responses as they arrive. All available audio chunks
    are submitted to the server on the first ``next()`` call.

    Args:
        audio_chunks (:obj:`Iterable[bytes]`): an iterable of raw audio
            fragments. For example, such raw audio can be obtained with

            .. code-block:: python

                import wave
                with wave.open(file_name, 'rb') as wav_f:
                    raw_audio = wav_f.readframes(n_frames)

        streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig`):
            a streaming config. Field descriptions are in the
            ``StreamingTranslateSpeechToSpeechConfig`` message of the `common repo
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
            Example of creating a streaming config:

            .. code-block:: python

                from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig
                config = RecognitionConfig(enable_automatic_punctuation=True)
                asr_config = StreamingRecognitionConfig(config, interim_results=True)
                translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US")
                tts_config = SynthesizeSpeechConfig(sample_rate_hz=44100, voice_name="English-US.Female-1")
                streaming_config = StreamingTranslateSpeechToSpeechConfig(asr_config, translation_config, tts_config)

    Yields:
        :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechResponse`:
            responses for the audio chunks. Field descriptions are in the
            ``StreamingTranslateSpeechToSpeechResponse`` message `here
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
    """
    requests = streaming_s2s_request_generator(audio_chunks, streaming_config)
    # Bidirectional streaming call; auth metadata accompanies the RPC.
    responses = self.stub.StreamingTranslateSpeechToSpeech(requests, metadata=self.auth.get_auth_metadata())
    for resp in responses:
        yield resp


def streaming_s2t_response_generator(
    self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToTextResponse, None, None]:
    """Perform "online" speech-to-text translation over a gRPC stream.

    Sends the audio fragments in :param:`audio_chunks` to the server and
    yields translation responses as they arrive. All available audio chunks
    are submitted to the server on the first ``next()`` call.

    Args:
        audio_chunks (:obj:`Iterable[bytes]`): an iterable of raw audio
            fragments. For example, such raw audio can be obtained with

            .. code-block:: python

                import wave
                with wave.open(file_name, 'rb') as wav_f:
                    raw_audio = wav_f.readframes(n_frames)

        streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextConfig`):
            a streaming config. Field descriptions are in the
            ``StreamingTranslateSpeechToTextConfig`` message of the `common repo
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
            Example of creating a streaming config:

            .. code-block:: python

                from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToTextConfig, TranslationConfig
                config = RecognitionConfig(enable_automatic_punctuation=True)
                asr_config = StreamingRecognitionConfig(config, interim_results=True)
                translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US")
                streaming_config = StreamingTranslateSpeechToTextConfig(asr_config, translation_config)

    Yields:
        :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextResponse`:
            responses for the audio chunks. Field descriptions are in the
            ``StreamingTranslateSpeechToTextResponse`` message `here
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
    """
    requests = streaming_s2t_request_generator(audio_chunks, streaming_config)
    # Bidirectional streaming call; auth metadata accompanies the RPC.
    responses = self.stub.StreamingTranslateSpeechToText(requests, metadata=self.auth.get_auth_metadata())
    for resp in responses:
        yield resp


def translate(
self,
texts: List[str],
Expand Down

0 comments on commit 22438ab

Please sign in to comment.