Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Amazon Bedrock Retriever to leverage Knowledge Base deployed onto AWS #1219

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Bedrock support for the Anthropic Claude Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
- Bedrock support for the Meta LLama 2, 3 and 3.1 Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
- Meta LLama3.1 405b tool call support
- Bedrock support for querying Knowledge Base

## LICENSE

Expand Down
2 changes: 2 additions & 0 deletions packages/community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
},
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.642.0",
"@aws-sdk/client-bedrock-agent-runtime": "^3.642.0",
"llamaindex": "workspace:*",
"@llamaindex/core": "workspace:*",
"@llamaindex/env": "workspace:*"
}
Expand Down
1 change: 1 addition & 0 deletions packages/community/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export {
BEDROCK_MODEL_MAX_TOKENS,
Bedrock,
} from "./llm/bedrock/index.js";
export { AmazonKnowledgeBaseRetriever } from "./retrievers/bedrock.js";
160 changes: 160 additions & 0 deletions packages/community/src/retrievers/bedrock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import {
BedrockAgentRuntimeClient,
RetrieveCommand,
type BedrockAgentRuntimeClientConfig,
type RetrievalFilter,
type SearchType,
} from "@aws-sdk/client-bedrock-agent-runtime";

import { BaseRetriever, Document } from "llamaindex";
Copy link
Member

@himself65 himself65 Sep 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import Document from llamaindex/core/schema and copy BaseRetriever type here for now.

llamaindex is the to end user level pacakge and shouldn't be imported from our other pacakge, so please remove llamaindex from deps

remind me that we need move BaseRetriever into core module. If you are interested please follow python repo file structure and my other PR in this repo. Or I will do it tmr

Copy link
Author

@ajohn-wick ajohn-wick Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you a lot for your review, please proceed with moving BaseRetriever into core module. Then, I will update code to import it if you are OKay @himself65?
Otherwise, I will copy BaseRetriever type here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

too late to me, will do tmr


/**
* Interface for the arguments required to initialize an
* AmazonKnowledgeBaseRetriever instance.
*/
export interface AmazonKnowledgeBaseRetrieverArgs {
knowledgeBaseId: string;
topK: number;
region: string;
clientOptions?: BedrockAgentRuntimeClientConfig;
filter?: RetrievalFilter;
overrideSearchType?: SearchType;
}

/**
* Class for interacting with Amazon Bedrock Knowledge Bases, a RAG workflow oriented service
* Extends the BaseRetriever class.
* @example
* ```typescript
* const retriever = new AmazonKnowledgeBaseRetriever({
* topK: 10,
* knowledgeBaseId: "YOUR_KNOWLEDGE_BASE_ID",
* region: "us-east-2",
* clientOptions: {
* credentials: {
* accessKeyId: "YOUR_ACCESS_KEY_ID",
* secretAccessKey: "YOUR_SECRET_ACCESS_KEY",
* },
* },
* });
*
* const docs = await retriever.retrieve({query: "How are clouds formed?"});
* ```
*/
export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
static lc_name() {
return "AmazonKnowledgeBaseRetriever";
}

lc_namespace = ["llamaindex", "retrievers", "amazon_bedrock_knowledge_base"];

knowledgeBaseId: string;

topK: number;

bedrockAgentRuntimeClient: BedrockAgentRuntimeClient;

filter?: RetrievalFilter;

overrideSearchType?: SearchType;

constructor({
knowledgeBaseId,
topK = 10,
clientOptions,
region,
filter,
overrideSearchType,
}: AmazonKnowledgeBaseRetrieverArgs) {
super();

this.topK = topK;
this.filter = filter;
this.overrideSearchType = overrideSearchType;
this.bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({
region,
...clientOptions,
});
this.knowledgeBaseId = knowledgeBaseId;
}

/**
* Cleans the result text by replacing sequences of whitespace with a
* single space and removing ellipses.
* @param resText The result text to clean.
* @returns The cleaned result text.
*/
cleanResult(resText: string) {
const res = resText.replace(/\s+/g, " ").replace(/\.\.\./g, "");
return res;
}

async queryKnowledgeBase(
query: string,
topK: number,
filter?: RetrievalFilter,
overrideSearchType?: SearchType,
) {
const retrieveCommand = new RetrieveCommand({
knowledgeBaseId: this.knowledgeBaseId,
retrievalQuery: {
text: query,
},
retrievalConfiguration: {
vectorSearchConfiguration: {
numberOfResults: topK,
overrideSearchType,
filter,
},
},
});

const retrieveResponse =
await this.bedrockAgentRuntimeClient.send(retrieveCommand);

return (
retrieveResponse.retrievalResults?.map((result: any) => {
let source;
switch (result.location?.type) {
case "CONFLUENCE":
source = result.location?.confluenceLocation?.url;
break;
case "S3":
source = result.location?.s3Location?.uri;
break;
case "SALESFORCE":
source = result.location?.salesforceLocation?.url;
break;
case "SHAREPOINT":
source = result.location?.sharePointLocation?.url;
break;
case "WEB":
source = result.location?.webLocation?.url;
break;
default:
source = result.location?.s3Location?.uri;
break;
}

return {
pageContent: this.cleanResult(result.content?.text || ""),
metadata: {
source,
score: result.score,
...result.metadata,
},
};
}) ?? ([] as Array<Document>)
);
}

async _retrieve(query: string): Promise<Document[]> {
const docs = await this.queryKnowledgeBase(
query,
this.topK,
this.filter,
this.overrideSearchType,
);
return docs;
}
}
Loading