Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Amazon Bedrock Retriever #1219

Merged
merged 11 commits into from
Sep 23, 2024
5 changes: 5 additions & 0 deletions .changeset/sour-dots-try.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@llamaindex/community": patch
---

feat: add Amazon Bedrock Retriever
1 change: 1 addition & 0 deletions packages/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Bedrock support for the Anthropic Claude Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
- Bedrock support for the Meta LLama 2, 3 and 3.1 Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
- Meta LLama3.1 405b tool call support
- Bedrock support for querying Knowledge Base

## LICENSE

Expand Down
1 change: 1 addition & 0 deletions packages/community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
},
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.642.0",
"@aws-sdk/client-bedrock-agent-runtime": "^3.642.0",
"@llamaindex/core": "workspace:*",
"@llamaindex/env": "workspace:*"
}
Expand Down
1 change: 1 addition & 0 deletions packages/community/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export {
BEDROCK_MODEL_MAX_TOKENS,
Bedrock,
} from "./llm/bedrock/index.js";
export { AmazonKnowledgeBaseRetriever } from "./retrievers/bedrock.js";
165 changes: 165 additions & 0 deletions packages/community/src/retrievers/bedrock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import type { KnowledgeBaseVectorSearchConfiguration } from "@aws-sdk/client-bedrock-agent-runtime";
import {
BedrockAgentRuntimeClient,
type BedrockAgentRuntimeClientConfig,
type RetrievalFilter,
RetrieveCommand,
type SearchType,
} from "@aws-sdk/client-bedrock-agent-runtime";
import type { QueryBundle } from "@llamaindex/core/query-engine";
import { BaseRetriever } from "@llamaindex/core/retriever";
import { Document, type NodeWithScore } from "@llamaindex/core/schema";
import { extractText } from "@llamaindex/core/utils";

/**
* Interface for the arguments required to initialize an
* AmazonKnowledgeBaseRetriever instance.
*/
export interface AmazonKnowledgeBaseRetrieverArgs {
knowledgeBaseId: string;
topK: number;
region: string;
clientOptions?: BedrockAgentRuntimeClientConfig;
filter?: RetrievalFilter;
overrideSearchType?: SearchType;
}

/**
* Class for interacting with Amazon Bedrock Knowledge Bases, a RAG workflow oriented service
* Extends the BaseRetriever class.
* @example
* ```typescript
* const retriever = new AmazonKnowledgeBaseRetriever({
* topK: 10,
* knowledgeBaseId: "YOUR_KNOWLEDGE_BASE_ID",
* region: "us-east-2",
* clientOptions: {
* credentials: {
* accessKeyId: "YOUR_ACCESS_KEY_ID",
* secretAccessKey: "YOUR_SECRET_ACCESS_KEY",
* },
* },
* });
*
* const docs = await retriever.retrieve({query: "How are clouds formed?"});
* ```
*/
export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
static lc_name() {
return "AmazonKnowledgeBaseRetriever";
}

lc_namespace = ["llamaindex", "retrievers", "amazon_bedrock_knowledge_base"];

knowledgeBaseId: string;

topK: number;

bedrockAgentRuntimeClient: BedrockAgentRuntimeClient;

filter: RetrievalFilter | undefined;

overrideSearchType: SearchType | undefined;

constructor({
knowledgeBaseId,
topK = 10,
clientOptions,
region,
filter,
overrideSearchType,
}: AmazonKnowledgeBaseRetrieverArgs) {
super();

this.topK = topK;
this.filter = filter;
this.overrideSearchType = overrideSearchType;
this.bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({
region,
...clientOptions,
});
this.knowledgeBaseId = knowledgeBaseId;
}

/**
* Cleans the result text by replacing sequences of whitespace with a
* single space and removing ellipses.
* @param resText The result text to clean.
* @returns The cleaned result text.
*/
cleanResult(resText: string) {
const res = resText.replace(/\s+/g, " ").replace(/\.\.\./g, "");
return res;
}

async queryKnowledgeBase(
query: QueryBundle,
topK: number,
filter?: RetrievalFilter,
overrideSearchType?: SearchType,
): Promise<NodeWithScore[]> {
const retrieveCommand = new RetrieveCommand({
knowledgeBaseId: this.knowledgeBaseId,
retrievalQuery: {
text: extractText(query),
},
retrievalConfiguration: {
vectorSearchConfiguration: {
numberOfResults: topK,
overrideSearchType,
filter,
} as KnowledgeBaseVectorSearchConfiguration,
},
});

const retrieveResponse =
await this.bedrockAgentRuntimeClient.send(retrieveCommand);

return (
retrieveResponse.retrievalResults?.map((result) => {
let source;
switch (result.location?.type) {
case "CONFLUENCE":
source = result.location?.confluenceLocation?.url;
break;
case "S3":
source = result.location?.s3Location?.uri;
break;
case "SALESFORCE":
source = result.location?.salesforceLocation?.url;
break;
case "SHAREPOINT":
source = result.location?.sharePointLocation?.url;
break;
case "WEB":
source = result.location?.webLocation?.url;
break;
default:
source = result.location?.s3Location?.uri;
break;
}

return {
node: new Document({
text: this.cleanResult(result.content?.text || ""),
metadata: {
source,
score: result.score,
...result.metadata,
},
}),
score: result.score ?? 1.0,
};
}) ?? []
);
}

async _retrieve(query: QueryBundle): Promise<NodeWithScore[]> {
return await this.queryKnowledgeBase(
query,
this.topK,
this.filter,
this.overrideSearchType,
);
}
}
Loading