Skip to content

Commit

Permalink
feat: support passing metadata (#53)
Browse files Browse the repository at this point in the history
* feat: support passing metadata

* pass credentials via metadata call credentials

* Update common proto submodule
  • Loading branch information
virajkarandikar authored Aug 23, 2023
1 parent 185e3ff commit b367963
Show file tree
Hide file tree
Showing 15 changed files with 31 additions and 17 deletions.
2 changes: 1 addition & 1 deletion common
1 change: 1 addition & 0 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ def add_connection_argparse_parameters(parser: argparse.ArgumentParser) -> argpa
parser.add_argument(
"--use-ssl", action='store_true', help="Boolean to control if SSL/TLS encryption should be used."
)
parser.add_argument("--metadata", action='append', nargs='+', help="Send HTTP Header(s) to server")
return parser
19 changes: 16 additions & 3 deletions riva/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,26 @@
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import grpc


def create_channel(
ssl_cert: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051",
ssl_cert: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051", metadata: Optional[List[Tuple[str, str]]] = None,
) -> grpc.Channel:

def metadata_callback(context, callback):
callback(metadata, None)

if ssl_cert is not None or use_ssl:
root_certificates = None
if ssl_cert is not None:
ssl_cert = Path(ssl_cert).expanduser()
with open(ssl_cert, 'rb') as f:
root_certificates = f.read()
creds = grpc.ssl_channel_credentials(root_certificates)
if metadata:
auth_creds = grpc.metadata_call_credentials(metadata_callback)
creds = grpc.composite_channel_credentials(creds, auth_creds)
channel = grpc.secure_channel(uri, creds)
else:
channel = grpc.insecure_channel(uri)
Expand All @@ -30,6 +36,7 @@ def __init__(
ssl_cert: Optional[Union[str, os.PathLike]] = None,
use_ssl: bool = False,
uri: str = "localhost:50051",
metadata_args: List[List[str]] = None,
) -> None:
"""
A class responsible for establishing connection with a server and providing security metadata.
Expand All @@ -44,7 +51,13 @@ def __init__(
self.ssl_cert: Optional[Path] = None if ssl_cert is None else Path(ssl_cert).expanduser()
self.uri: str = uri
self.use_ssl: bool = use_ssl
self.channel: grpc.Channel = create_channel(self.ssl_cert, self.use_ssl, self.uri)
self.metadata = []
if metadata_args:
for meta in metadata_args:
if len(meta) != 2:
raise ValueError(f"Metadata should have 2 parameters in \"key\" \"value\" pair. Receieved {len(meta)} parameters.")
self.metadata.append(tuple(meta))
self.channel: grpc.Channel = create_channel(self.ssl_cert, self.use_ssl, self.uri, self.metadata)

def get_auth_metadata(self) -> List[Tuple[str, str]]:
"""
Expand Down
2 changes: 1 addition & 1 deletion scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def streaming_transcription_worker(
) -> None:
output_file = Path(output_file).expanduser()
try:
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
asr_service = riva.client.ASRService(auth)
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
Expand Down
2 changes: 1 addition & 1 deletion scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def main() -> None:
if args.list_devices:
riva.client.audio_io.list_output_devices()
return
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
asr_service = riva.client.ASRService(auth)
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
Expand Down
2 changes: 1 addition & 1 deletion scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def parse_args() -> argparse.Namespace:

def main() -> None:
args = parse_args()
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
asr_service = riva.client.ASRService(auth)
config = riva.client.RecognitionConfig(
language_code=args.language_code,
Expand Down
2 changes: 1 addition & 1 deletion scripts/asr/transcribe_mic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def main() -> None:
if args.list_devices:
riva.client.audio_io.list_input_devices()
return
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
asr_service = riva.client.ASRService(auth)
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
Expand Down
2 changes: 1 addition & 1 deletion scripts/nlp/eval_intent_slot.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def parse_args() -> argparse.Namespace:

def main() -> None:
args = parse_args()
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
service = riva.client.NLPService(auth)
intent_report, slot_report = intent_slots_classification_report(
args.input_file,
Expand Down
2 changes: 1 addition & 1 deletion scripts/nlp/intentslot_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def pretty_print_result(

def main() -> None:
args = parse_args()
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
service = riva.client.NLPService(auth)
if args.interactive:
while True:
Expand Down
2 changes: 1 addition & 1 deletion scripts/nlp/ner_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def parse_args() -> argparse.Namespace:

def main() -> None:
args = parse_args()
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
service = riva.client.NLPService(auth)
tokens, slots, slot_confidences, starts, ends = riva.client.extract_most_probable_token_classification_predictions(
service.classify_tokens(input_strings=args.query, model_name=args.model)
Expand Down
4 changes: 2 additions & 2 deletions scripts/nlp/punctuation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def parse_args() -> argparse.Namespace:


def run_punct_capit(args: argparse.Namespace) -> None:
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
nlp_service = riva.client.NLPService(auth)
if args.interactive:
while True:
Expand Down Expand Up @@ -134,7 +134,7 @@ def run_tests(args: argparse.Namespace) -> int:
],
}

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
nlp_service = riva.client.NLPService(auth)

fail_count = 0
Expand Down
2 changes: 1 addition & 1 deletion scripts/nlp/qa_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def parse_args() -> argparse.Namespace:

def main() -> None:
args = parse_args()
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
service = riva.client.NLPService(auth)
resp = service.natural_query(args.query, args.context)
print(resp)
Expand Down
2 changes: 1 addition & 1 deletion scripts/nlp/text_classify_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def parse_args() -> argparse.Namespace:

def main() -> None:
args = parse_args()
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
service = riva.client.NLPService(auth)
print(riva.client.nlp.extract_most_probable_text_class_and_confidence(service.classify_text(args.query, args.model)))

Expand Down
2 changes: 1 addition & 1 deletion scripts/nmt/nmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def request(inputs,args):

args = parse_args()

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
nmt_client = riva.client.NeuralMachineTranslationClient(auth)

if args.list_models:
Expand Down
2 changes: 1 addition & 1 deletion scripts/tts/talk.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def main() -> None:
if args.list_devices:
riva.client.audio_io.list_output_devices()
return
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
service = riva.client.SpeechSynthesisService(auth)
nchannels = 1
sampwidth = 2
Expand Down

0 comments on commit b367963

Please sign in to comment.