diff --git a/lib/aws-genai-llm-chatbot-stack.ts b/lib/aws-genai-llm-chatbot-stack.ts index 78a0e55f3..cee06fa3e 100644 --- a/lib/aws-genai-llm-chatbot-stack.ts +++ b/lib/aws-genai-llm-chatbot-stack.ts @@ -56,6 +56,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { ragEngines: ragEngines, userPool: authentication.userPool, modelsParameter: models.modelsParameter, + bedrockEnabledModelsParameter: models.bedrockEnabledModelsParameter, models: models.models, }); diff --git a/lib/chatbot-api/index.ts b/lib/chatbot-api/index.ts index 4a8051a87..ef2e342ad 100644 --- a/lib/chatbot-api/index.ts +++ b/lib/chatbot-api/index.ts @@ -25,6 +25,7 @@ export interface ChatBotApiProps { readonly ragEngines?: RagEngines; readonly userPool: cognito.UserPool; readonly modelsParameter: ssm.StringParameter; + readonly bedrockEnabledModelsParameter: ssm.StringParameter; readonly models: SageMakerModelEndpoint[]; } diff --git a/lib/chatbot-api/rest-api.ts b/lib/chatbot-api/rest-api.ts index 5d8e1a03b..988ef031f 100644 --- a/lib/chatbot-api/rest-api.ts +++ b/lib/chatbot-api/rest-api.ts @@ -25,6 +25,7 @@ export interface ApiResolversProps { readonly byUserIdIndex: string; readonly userFeedbackBucket: s3.Bucket; readonly modelsParameter: ssm.StringParameter; + readonly bedrockEnabledModelsParameter: ssm.StringParameter; readonly models: SageMakerModelEndpoint[]; readonly api: appsync.GraphqlApi; } @@ -59,6 +60,7 @@ export class ApiResolvers extends Construct { ...props.shared.defaultEnvironmentVariables, CONFIG_PARAMETER_NAME: props.shared.configParameter.parameterName, MODELS_PARAMETER_NAME: props.modelsParameter.parameterName, + BEDROCK_ENABLED_MODELS_PARAMETER_NAME: props.bedrockEnabledModelsParameter.parameterName, X_ORIGIN_VERIFY_SECRET_ARN: props.shared.xOriginVerifySecret.secretArn, API_KEYS_SECRETS_ARN: props.shared.apiKeysSecret.secretArn, @@ -290,6 +292,7 @@ export class ApiResolvers extends Construct { props.shared.apiKeysSecret.grantRead(apiHandler); props.shared.configParameter.grantRead(apiHandler); props.modelsParameter.grantRead(apiHandler); + props.bedrockEnabledModelsParameter.grantRead(apiHandler); props.sessionsTable.grantReadWriteData(apiHandler); props.userFeedbackBucket.grantReadWrite(apiHandler); props.ragEngines?.uploadBucket.grantReadWrite(apiHandler); diff --git a/lib/models/index.ts b/lib/models/index.ts index 72a476cb4..0f6844ec5 100644 --- a/lib/models/index.ts +++ b/lib/models/index.ts @@ -28,11 +28,13 @@ export interface ModelsProps { export class Models extends Construct { public readonly models: SageMakerModelEndpoint[]; public readonly modelsParameter: ssm.StringParameter; + public readonly bedrockEnabledModelsParameter: ssm.StringParameter; constructor(scope: Construct, id: string, props: ModelsProps) { super(scope, id); const models: SageMakerModelEndpoint[] = []; + const bedrockEnabledModels: string[] = ['anthropic.claude-3-haiku-20240307-v1:0', 'anthropic.claude-3-sonnet-20240229-v1:0']; let hfTokenSecret: secretsmanager.Secret | undefined; if (props.config.llms.huggingfaceApiSecretArn) { @@ -386,8 +388,15 @@ export class Models extends Construct { ), }); + const bedrockEnabledModelsParameter = new ssm.StringParameter(this, "BedrockEnabledModelsParameter", { + stringValue: JSON.stringify( + bedrockEnabledModels + ), + }); + this.models = models; this.modelsParameter = modelsParameter; + this.bedrockEnabledModelsParameter = bedrockEnabledModelsParameter; if (models.length > 0 && props.config.llms?.sagemakerSchedule?.enabled) { const schedulerRole: iam.Role = new iam.Role(this, "SchedulerRole", { diff --git a/lib/shared/layers/python-sdk/python/genai_core/models.py b/lib/shared/layers/python-sdk/python/genai_core/models.py index 5d0a4a22a..41d59743a 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/models.py +++ b/lib/shared/layers/python-sdk/python/genai_core/models.py @@ -83,11 +83,14 @@ def list_bedrock_models(): byInferenceType=genai_core.types.InferenceType.ON_DEMAND.value, byOutputModality=genai_core.types.Modality.TEXT.value, ) + + enabledModels = genai_core.parameters.get_enabled_bedrock_models() + bedrock_models = [ m for m in response.get("modelSummaries", []) if m.get("modelLifecycle", {}).get("status") - == genai_core.types.ModelStatus.ACTIVE.value + == genai_core.types.ModelStatus.ACTIVE.value and m.get("modelId") in enabledModels ] models = [ diff --git a/lib/shared/layers/python-sdk/python/genai_core/parameters.py b/lib/shared/layers/python-sdk/python/genai_core/parameters.py index b7eb65a4c..eef363c5d 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/parameters.py +++ b/lib/shared/layers/python-sdk/python/genai_core/parameters.py @@ -5,6 +5,7 @@ API_KEYS_SECRETS_ARN = os.environ.get("API_KEYS_SECRETS_ARN") CONFIG_PARAMETER_NAME = os.environ.get("CONFIG_PARAMETER_NAME") MODELS_PARAMETER_NAME = os.environ.get("MODELS_PARAMETER_NAME") +BEDROCK_ENABLED_MODELS_PARAMETER_NAME = os.environ.get("BEDROCK_ENABLED_MODELS_PARAMETER_NAME") def get_external_api_key(name: str): @@ -32,3 +33,6 @@ def get_config(): def get_sagemaker_models(): return parameters.get_parameter(MODELS_PARAMETER_NAME, transform="json", max_age=30) + +def get_enabled_bedrock_models(): + return parameters.get_parameter(BEDROCK_ENABLED_MODELS_PARAMETER_NAME, transform="json", max_age=30)