
Feature: Support Cross Region Inference #629

Merged · 1 commit · Feb 4, 2025
1 change: 1 addition & 0 deletions lib/chatbot-api/rest-api.ts
@@ -331,6 +331,7 @@ export class ApiResolvers extends Construct {
actions: [
"bedrock:ListFoundationModels",
"bedrock:ListCustomModels",
"bedrock:ListInferenceProfiles",
"bedrock:InvokeModel",
"bedrock:InvokeModelWithResponseStream",
],
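The added bedrock:ListInferenceProfiles action is what lets the resolver's role enumerate cross-region inference profiles. A minimal sketch of the corresponding call, assuming boto3 with credentials and region already configured:

import boto3

# Needs the bedrock:ListInferenceProfiles permission granted above.
bedrock = boto3.client("bedrock")
response = bedrock.list_inference_profiles()
for profile in response.get("inferenceProfileSummaries", []):
    print(profile["inferenceProfileId"], profile.get("status"), profile.get("type"))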
@@ -14,3 +14,4 @@ def handle_run(self, input: dict, model_kwargs: dict, files: Optional[list] = None


registry.register(r"^bedrock.anthropic.claude-3.*", Claude3)
registry.register(r"^bedrock.*.anthropic.claude-3.*", Claude3)
@@ -21,3 +21,10 @@
registry.register(r"^bedrock.amazon.nova-reel*", BedrockChatMediaGeneration)
registry.register(r"^bedrock.amazon.nova-canvas*", BedrockChatMediaGeneration)
registry.register(r"^bedrock.amazon.nova*", BedrockChatAdapter)
registry.register(r"^bedrock.*.amazon.nova*", BedrockChatAdapter)
registry.register(r"^bedrock.*.anthropic.claude*", BedrockChatAdapter)
registry.register(r"^bedrock.*.meta.llama*", BedrockChatAdapter)
registry.register(r"^bedrock.*.mistral.mistral-large*", BedrockChatAdapter)
registry.register(r"^bedrock.*.mistral.mistral-small*", BedrockChatAdapter)
registry.register(r"^bedrock.*.mistral.mistral-7b-*", BedrockChatNoSystemPromptAdapter)
registry.register(r"^bedrock.*.mistral.mixtral-*", BedrockChatNoSystemPromptAdapter)
101 changes: 74 additions & 27 deletions lib/shared/layers/python-sdk/python/genai_core/models.py
@@ -16,6 +16,10 @@ def list_models():
if bedrock_models:
models.extend(bedrock_models)

bedrock_cris_models = list_bedrock_cris_models()
if bedrock_cris_models:
models.extend(bedrock_cris_models)

fine_tuned_models = list_bedrock_finetuned_models()
if fine_tuned_models:
models.extend(fine_tuned_models)
@@ -80,6 +84,73 @@ def list_azure_openai_models():
]


# Based on the table (need to support both document and system prompt)
# https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html
def does_model_support_documents(model_name):
return (
not re.match(r"^ai21.jamba*", model_name)
and not re.match(r"^ai21.j2*", model_name)
and not re.match(r"^amazon.titan-t*", model_name)
and not re.match(r"^cohere.command-light*", model_name)
and not re.match(r"^cohere.command-text*", model_name)
and not re.match(r"^mistral.mistral-7b-instruct-*", model_name)
and not re.match(r"^mistral.mistral-small*", model_name)
and not re.match(r"^amazon.nova-reel*", model_name)
and not re.match(r"^amazon.nova-canvas*", model_name)
and not re.match(r"^amazon.nova-micro*", model_name)
)
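For illustration, assuming the function above and its re import are in scope, representative model IDs behave as follows:

# Claude 3 matches none of the exclusion patterns, so documents are supported:
assert does_model_support_documents("anthropic.claude-3-sonnet-20240229-v1:0")
# Mistral Small and Nova Micro are excluded above:
assert not does_model_support_documents("mistral.mistral-small-2402-v1:0")
assert not does_model_support_documents("amazon.nova-micro-v1:0")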


def create_bedrock_model_profile(bedrock_model: dict, model_name: str) -> dict:
model = {
"provider": Provider.BEDROCK.value,
"name": model_name,
"streaming": bedrock_model.get("responseStreamingSupported", False),
"inputModalities": bedrock_model["inputModalities"],
"outputModalities": bedrock_model["outputModalities"],
"interface": ModelInterface.LANGCHAIN.value,
"ragSupported": True,
"bedrockGuardrails": True,
}

if does_model_support_documents(model["name"]):
model["inputModalities"].append("DOCUMENT")
return model
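A minimal sketch of the profile this builds, fed with an illustrative model summary rather than real API output. Note that for cross-region entries the second argument is the inference profile ID, not the on-demand model ID:

model_summary = {
    "responseStreamingSupported": True,
    "inputModalities": ["TEXT", "IMAGE"],
    "outputModalities": ["TEXT"],
}
# Illustrative profile ID; "name" becomes "us.anthropic...", which presumably
# composes with the provider into keys like "bedrock.us.anthropic.claude-3-5-..."
# that the new registry patterns match.
profile = create_bedrock_model_profile(
    model_summary, "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
)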


def list_cross_region_inference_profiles():
bedrock = genai_core.clients.get_bedrock_client(service_name="bedrock")
response = bedrock.list_inference_profiles()

return {
inference_profile["models"][0]["modelArn"].split("/")[1]: inference_profile[
"inferenceProfileId"
]
for inference_profile in response.get("inferenceProfileSummaries", [])
if (
inference_profile.get("status") == "ACTIVE"
and inference_profile.get("type") == "SYSTEM_DEFINED"
)
}
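The result maps each on-demand model ID (parsed out of the profile's first model ARN) to its system-defined profile ID. The shape, with illustrative IDs:

# {
#     "anthropic.claude-3-5-sonnet-20240620-v1:0": "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
#     "meta.llama3-1-70b-instruct-v1:0": "us.meta.llama3-1-70b-instruct-v1:0",
# }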


def list_bedrock_cris_models():
Collaborator: You mean list_bedrock_cross_models?

Author: "cris" is short for cross-region inference profiles.

Collaborator: Approved, thank you!

Author: Sure, thanks!

try:
cross_region_profiles = list_cross_region_inference_profiles()
bedrock_client = genai_core.clients.get_bedrock_client(service_name="bedrock")
all_models = bedrock_client.list_foundation_models()["modelSummaries"]

return [
create_bedrock_model_profile(model, cross_region_profiles[model["modelId"]])
for model in all_models
if genai_core.types.InferenceType.INFERENCE_PROFILE.value
in model["inferenceTypesSupported"]
]
except Exception as e:
logger.error(f"Error listing cross region inference profiles models: {e}")
return None


def list_bedrock_models():
try:
bedrock = genai_core.clients.get_bedrock_client(service_name="bedrock")
@@ -108,33 +179,9 @@ def list_bedrock_models():
)
):
continue
model = {
"provider": Provider.BEDROCK.value,
"name": bedrock_model["modelId"],
"streaming": bedrock_model.get("responseStreamingSupported", False),
"inputModalities": bedrock_model["inputModalities"],
"outputModalities": bedrock_model["outputModalities"],
"interface": ModelInterface.LANGCHAIN.value,
"ragSupported": True,
"bedrockGuardrails": True,
}
# Based on the table (need to support both document and system prompt)
# https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html
if (
not re.match(r"^ai21.jamba*", model["name"])
and not re.match(r"^ai21.j2*", model["name"])
and not re.match(r"^amazon.titan-t*", model["name"])
and not re.match(r"^cohere.command-light*", model["name"])
and not re.match(r"^cohere.command-text*", model["name"])
and not re.match(r"^mistral.mistral-7b-instruct-*", model["name"])
and not re.match(r"^mistral.mistral-small*", model["name"])
and not re.match(r"^amazon.nova-reel*", model["name"])
and not re.match(r"^amazon.nova-canvas*", model["name"])
and not re.match(r"^amazon.nova-micro*", model["name"])
):
model["inputModalities"].append("DOCUMENT")

models.append(model)
models.append(
create_bedrock_model_profile(bedrock_model, bedrock_model["modelId"])
)

return models
except Exception as e:
1 change: 1 addition & 0 deletions lib/shared/layers/python-sdk/python/genai_core/types.py
@@ -52,6 +52,7 @@ class Modality(Enum):
class InferenceType(Enum):
ON_DEMAND = "ON_DEMAND"
PROVISIONED = "PROVISIONED"
INFERENCE_PROFILE = "INFERENCE_PROFILE"


class ModelStatus(Enum):
1 change: 1 addition & 0 deletions tests/__snapshots__/cdk-app.test.ts.snap
@@ -4428,6 +4428,7 @@ schema {
"Action": [
"bedrock:ListFoundationModels",
"bedrock:ListCustomModels",
"bedrock:ListInferenceProfiles",
"bedrock:InvokeModel",
"bedrock:InvokeModelWithResponseStream",
],
@@ -4407,6 +4407,7 @@ schema {
"Action": [
"bedrock:ListFoundationModels",
"bedrock:ListCustomModels",
"bedrock:ListInferenceProfiles",
"bedrock:InvokeModel",
"bedrock:InvokeModelWithResponseStream",
],