Skip to content

Commit

Permalink
Update for AWS Bedrock knowledge bases to support filters and overrideSearchType, Update the KB to support other locations as sources (langchain-ai#6189)
Browse files Browse the repository at this point in the history

Co-authored-by: Brace Sproul <[email protected]>
  • Loading branch information
jl4nz and bracesproul authored Jul 24, 2024
1 parent 29c08a3 commit 63305a0
Show file tree
Hide file tree
Showing 4 changed files with 931 additions and 16 deletions.
6 changes: 3 additions & 3 deletions libs/langchain-aws/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@aws-sdk/client-bedrock-agent-runtime": "^3.583.0",
"@aws-sdk/client-bedrock-agent-runtime": "^3.616.0",
"@aws-sdk/client-bedrock-runtime": "^3.602.0",
"@aws-sdk/client-kendra": "^3.352.0",
"@aws-sdk/credential-provider-node": "^3.600.0",
"@langchain/core": ">=0.2.16 <0.3.0",
"zod-to-json-schema": "^3.22.5"
},
"devDependencies": {
"@aws-sdk/types": "^3.598.0",
"@aws-sdk/types": "^3.609.0",
"@jest/globals": "^29.5.0",
"@langchain/scripts": "~0.0.14",
"@langchain/standard-tests": "0.0.0",
Expand Down Expand Up @@ -97,4 +97,4 @@
"index.d.ts",
"index.d.cts"
]
}
}
68 changes: 58 additions & 10 deletions libs/langchain-aws/src/retrievers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import {
RetrieveCommand,
BedrockAgentRuntimeClient,
type BedrockAgentRuntimeClientConfig,
type SearchType,
type RetrievalFilter,
} from "@aws-sdk/client-bedrock-agent-runtime";

import { BaseRetriever } from "@langchain/core/retrievers";
Expand All @@ -16,6 +18,8 @@ export interface AmazonKnowledgeBaseRetrieverArgs {
topK: number;
region: string;
clientOptions?: BedrockAgentRuntimeClientConfig;
filter?: RetrievalFilter;
overrideSearchType?: SearchType;
}

/**
Expand Down Expand Up @@ -51,15 +55,23 @@ export class AmazonKnowledgeBaseRetriever extends BaseRetriever {

bedrockAgentRuntimeClient: BedrockAgentRuntimeClient;

filter?: RetrievalFilter;

overrideSearchType?: SearchType;

constructor({
knowledgeBaseId,
topK = 10,
clientOptions,
region,
filter,
overrideSearchType,
}: AmazonKnowledgeBaseRetrieverArgs) {
super();

this.topK = topK;
this.filter = filter;
this.overrideSearchType = overrideSearchType;
this.bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({
region,
...clientOptions,
Expand All @@ -78,7 +90,12 @@ export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
return res;
}

async queryKnowledgeBase(query: string, topK: number) {
async queryKnowledgeBase(
query: string,
topK: number,
filter?: RetrievalFilter,
overrideSearchType?: SearchType
) {
const retrieveCommand = new RetrieveCommand({
knowledgeBaseId: this.knowledgeBaseId,
retrievalQuery: {
Expand All @@ -87,6 +104,8 @@ export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
retrievalConfiguration: {
vectorSearchConfiguration: {
numberOfResults: topK,
overrideSearchType,
filter,
},
},
});
Expand All @@ -96,19 +115,48 @@ export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
);

return (
retrieveResponse.retrievalResults?.map((result) => ({
pageContent: this.cleanResult(result.content?.text || ""),
metadata: {
source: result.location?.s3Location?.uri,
score: result.score,
...result.metadata,
},
})) ?? ([] as Array<Document>)
retrieveResponse.retrievalResults?.map((result) => {
let source;
switch (result.location?.type) {
case "CONFLUENCE":
source = result.location?.confluenceLocation?.url;
break;
case "S3":
source = result.location?.s3Location?.uri;
break;
case "SALESFORCE":
source = result.location?.salesforceLocation?.url;
break;
case "SHAREPOINT":
source = result.location?.sharePointLocation?.url;
break;
case "WEB":
source = result.location?.webLocation?.url;
break;
default:
source = result.location?.s3Location?.uri;
break;
}

return {
pageContent: this.cleanResult(result.content?.text || ""),
metadata: {
source,
score: result.score,
...result.metadata,
},
};
}) ?? ([] as Array<Document>)
);
}

async _getRelevantDocuments(query: string): Promise<Document[]> {
const docs = await this.queryKnowledgeBase(query, this.topK);
const docs = await this.queryKnowledgeBase(
query,
this.topK,
this.filter,
this.overrideSearchType
);
return docs;
}
}
2 changes: 2 additions & 0 deletions libs/langchain-aws/src/retrievers/tests/bedrock.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ test("AmazonKnowledgeBaseRetriever", async () => {
topK: 10,
knowledgeBaseId: process.env.AMAZON_KNOWLEDGE_BASE_ID || "",
region: process.env.BEDROCK_AWS_REGION,
overrideSearchType: "HYBRID",
filter: undefined,
clientOptions: {
credentials: {
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID,
Expand Down
Loading

0 comments on commit 63305a0

Please sign in to comment.