Skip to content

Commit

Permalink
Add AWS Kendra Retriever (langchain-ai#1789)
Browse files Browse the repository at this point in the history
* feat: add kendra integration

* add: integrate kendra retriever api

* linting: fix warnings

* add: docs + examples

* add: clientOptions + error if no region & indexId

* Rename to AmazonKendraRetriever, add test

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
NickMandylas and jacoblee93 authored Jul 19, 2023
1 parent 460e110 commit e02a886
Show file tree
Hide file tree
Showing 11 changed files with 649 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
---
hide_table_of_contents: true
---

# Amazon Kendra Retriever

Amazon Kendra is an intelligent search service provided by Amazon Web Services (AWS). It utilizes advanced natural language processing (NLP) and machine learning algorithms to enable powerful search capabilities across various data sources within an organization. Kendra is designed to help users find the information they need quickly and accurately, improving productivity and decision-making.

With Kendra, users can search across a wide range of content types, including documents, FAQs, knowledge bases, manuals, and websites. It supports multiple languages and can understand complex queries, synonyms, and contextual meanings to provide highly relevant search results.

## Setup

```bash npm2yarn
npm i @aws-sdk/client-kendra
```

## Usage

import CodeBlock from "@theme/CodeBlock";
import Example from "@examples/retrievers/kendra.ts";

<CodeBlock language="typescript">{Example}</CodeBlock>
17 changes: 17 additions & 0 deletions examples/src/retrievers/kendra.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { AmazonKendraRetriever } from "langchain/retrievers/amazon_kendra";

const retriever = new AmazonKendraRetriever({
topK: 10,
indexId: "YOUR_INDEX_ID",
region: "us-east-2", // Your region
clientOptions: {
credentials: {
accessKeyId: "YOUR_ACCESS_KEY_ID",
secretAccessKey: "YOUR_SECRET_ACCESS_KEY",
},
},
});

const docs = await retriever.getRelevantDocuments("How are clouds formed?");

console.log(docs);
3 changes: 3 additions & 0 deletions langchain/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ output_parsers/expression.d.ts
retrievers.cjs
retrievers.js
retrievers.d.ts
retrievers/amazon_kendra.cjs
retrievers/amazon_kendra.js
retrievers/amazon_kendra.d.ts
retrievers/remote.cjs
retrievers/remote.js
retrievers/remote.d.ts
Expand Down
13 changes: 13 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@
"retrievers.cjs",
"retrievers.js",
"retrievers.d.ts",
"retrievers/amazon_kendra.cjs",
"retrievers/amazon_kendra.js",
"retrievers/amazon_kendra.d.ts",
"retrievers/remote.cjs",
"retrievers/remote.js",
"retrievers/remote.d.ts",
Expand Down Expand Up @@ -499,6 +502,7 @@
"license": "MIT",
"devDependencies": {
"@aws-sdk/client-dynamodb": "^3.310.0",
"@aws-sdk/client-kendra": "^3.352.0",
"@aws-sdk/client-lambda": "^3.310.0",
"@aws-sdk/client-s3": "^3.310.0",
"@aws-sdk/client-sagemaker-runtime": "^3.310.0",
Expand Down Expand Up @@ -590,6 +594,7 @@
},
"peerDependencies": {
"@aws-sdk/client-dynamodb": "^3.310.0",
"@aws-sdk/client-kendra": "^3.352.0",
"@aws-sdk/client-lambda": "^3.310.0",
"@aws-sdk/client-s3": "^3.310.0",
"@aws-sdk/client-sagemaker-runtime": "^3.310.0",
Expand Down Expand Up @@ -651,6 +656,9 @@
"@aws-sdk/client-dynamodb": {
"optional": true
},
"@aws-sdk/client-kendra": {
"optional": true
},
"@aws-sdk/client-lambda": {
"optional": true
},
Expand Down Expand Up @@ -1458,6 +1466,11 @@
"require": "./retrievers.cjs"
}
},
"./retrievers/amazon_kendra": {
"types": "./retrievers/amazon_kendra.d.ts",
"import": "./retrievers/amazon_kendra.js",
"require": "./retrievers/amazon_kendra.cjs"
},
"./retrievers/remote": {
"types": "./retrievers/remote.d.ts",
"import": "./retrievers/remote.js",
Expand Down
8 changes: 6 additions & 2 deletions langchain/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ const entrypoints = {
"document_loaders/web/s3": "document_loaders/web/s3",
"document_loaders/web/sonix_audio": "document_loaders/web/sonix_audio",
"document_loaders/web/confluence": "document_loaders/web/confluence",
"document_loaders/web/sort_xyz_blockchain": "document_loaders/web/sort_xyz_blockchain",
"document_loaders/web/sort_xyz_blockchain":
"document_loaders/web/sort_xyz_blockchain",
"document_loaders/fs/directory": "document_loaders/fs/directory",
"document_loaders/fs/buffer": "document_loaders/fs/buffer",
"document_loaders/fs/text": "document_loaders/fs/text",
Expand All @@ -121,7 +122,8 @@ const entrypoints = {
"document_loaders/fs/notion": "document_loaders/fs/notion",
"document_loaders/fs/unstructured": "document_loaders/fs/unstructured",
// document_transformers
"document_transformers/openai_functions": "document_transformers/openai_functions",
"document_transformers/openai_functions":
"document_transformers/openai_functions",
// chat_models
chat_models: "chat_models/index",
"chat_models/base": "chat_models/base",
Expand All @@ -144,6 +146,7 @@ const entrypoints = {
"output_parsers/expression": "output_parsers/expression",
// retrievers
retrievers: "retrievers/index",
"retrievers/amazon_kendra": "retrievers/amazon_kendra",
"retrievers/remote": "retrievers/remote/index",
"retrievers/supabase": "retrievers/supabase",
"retrievers/zep": "retrievers/zep",
Expand Down Expand Up @@ -275,6 +278,7 @@ const requiresOptionalDependency = [
"chat_models/googlevertexai",
"chat_models/googlepalm",
"sql_db",
"retrievers/amazon_kendra",
"retrievers/supabase",
"retrievers/zep",
"retrievers/metal",
Expand Down
1 change: 1 addition & 0 deletions langchain/src/load/import_constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ export const optionalImportEntrypoints = [
"langchain/chat_models/googlepalm",
"langchain/sql_db",
"langchain/output_parsers/expression",
"langchain/retrievers/amazon_kendra",
"langchain/retrievers/supabase",
"langchain/retrievers/zep",
"langchain/retrievers/metal",
Expand Down
3 changes: 3 additions & 0 deletions langchain/src/load/import_type.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ export interface OptionalImportMap {
"langchain/output_parsers/expression"?:
| typeof import("../output_parsers/expression.js")
| Promise<typeof import("../output_parsers/expression.js")>;
"langchain/retrievers/amazon_kendra"?:
| typeof import("../retrievers/amazon_kendra.js")
| Promise<typeof import("../retrievers/amazon_kendra.js")>;
"langchain/retrievers/supabase"?:
| typeof import("../retrievers/supabase.js")
| Promise<typeof import("../retrievers/supabase.js")>;
Expand Down
230 changes: 230 additions & 0 deletions langchain/src/retrievers/amazon_kendra.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
import {
AttributeFilter,
DocumentAttribute,
DocumentAttributeValue,
KendraClient,
KendraClientConfig,
QueryCommand,
QueryCommandOutput,
QueryResultItem,
RetrieveCommand,
RetrieveCommandOutput,
RetrieveResultItem,
} from "@aws-sdk/client-kendra";

import { BaseRetriever } from "../schema/retriever.js";
import { Document } from "../document.js";

export interface AmazonKendraRetrieverArgs {
indexId: string;
topK: number;
region: string;
attributeFilter?: AttributeFilter;
clientOptions?: KendraClientConfig;
}

export class AmazonKendraRetriever extends BaseRetriever {
lc_namespace = ["langchain", "retrievers", "amazon_kendra"];

indexId: string;

topK: number;

kendraClient: KendraClient;

attributeFilter?: AttributeFilter;

constructor({
indexId,
topK = 10,
clientOptions,
attributeFilter,
region,
}: AmazonKendraRetrieverArgs) {
super();

if (!region) {
throw new Error("Please pass regionName field to the constructor!");
}

if (!indexId) {
throw new Error("Please pass Kendra Index Id to the constructor");
}

this.topK = topK;
this.kendraClient = new KendraClient({
region,
...clientOptions,
});
this.attributeFilter = attributeFilter;
this.indexId = indexId;
}

// A method to combine title and excerpt into a single string.
combineText(title?: string, excerpt?: string): string {
let text = "";
if (title) {
text += `Document Title: ${title}\n`;
}
if (excerpt) {
text += `Document Excerpt: \n${excerpt}\n`;
}
return text;
}

// A method to clean the result text by replacing sequences of whitespace with a single space and removing ellipses.
cleanResult(resText: string) {
const res = resText.replace(/\s+/g, " ").replace(/\.\.\./g, "");
return res;
}

// A method to extract the attribute value from a DocumentAttributeValue object.
getDocAttributeValue(value: DocumentAttributeValue) {
if (value.DateValue) {
return value.DateValue;
}
if (value.LongValue) {
return value.LongValue;
}
if (value.StringListValue) {
return value.StringListValue;
}
if (value.StringValue) {
return value.StringValue;
}
return "";
}

// A method to extract the attribute key-value pairs from an array of DocumentAttribute objects.
getDocAttributes(documentAttributes?: DocumentAttribute[]): {
[key: string]: unknown;
} {
const attributes: { [key: string]: unknown } = {};
if (documentAttributes) {
for (const attr of documentAttributes) {
if (attr.Key && attr.Value) {
attributes[attr.Key] = this.getDocAttributeValue(attr.Value);
}
}
}
return attributes;
}

// A method to convert a RetrieveResultItem object into a Document object.
convertRetrieverItem(item: RetrieveResultItem) {
const title = item.DocumentTitle || "";
const excerpt = item.Content ? this.cleanResult(item.Content) : "";
const pageContent = this.combineText(title, excerpt);
const source = item.DocumentURI;
const attributes = this.getDocAttributes(item.DocumentAttributes);
const metadata = {
source,
title,
excerpt,
document_attributes: attributes,
};

return new Document({ pageContent, metadata });
}

// A method to extract the top-k documents from a RetrieveCommandOutput object.
getRetrieverDocs(
response: RetrieveCommandOutput,
pageSize: number
): Document[] {
if (!response.ResultItems) return [];
const { length } = response.ResultItems;
const count = length < pageSize ? length : pageSize;

return response.ResultItems.slice(0, count).map((item) =>
this.convertRetrieverItem(item)
);
}

// A method to extract the excerpt text from a QueryResultItem object.
getQueryItemExcerpt(item: QueryResultItem) {
if (
item.AdditionalAttributes &&
item.AdditionalAttributes[0].Key === "AnswerText"
) {
if (!item.AdditionalAttributes) {
return "";
}
if (!item.AdditionalAttributes[0]) {
return "";
}

return this.cleanResult(
item.AdditionalAttributes[0].Value?.TextWithHighlightsValue?.Text || ""
);
} else if (item.DocumentExcerpt) {
return this.cleanResult(item.DocumentExcerpt.Text || "");
} else {
return "";
}
}

// A method to convert a QueryResultItem object into a Document object.
convertQueryItem(item: QueryResultItem) {
const title = item.DocumentTitle?.Text || "";
const excerpt = this.getQueryItemExcerpt(item);
const pageContent = this.combineText(title, excerpt);
const source = item.DocumentURI;
const attributes = this.getDocAttributes(item.DocumentAttributes);
const metadata = {
source,
title,
excerpt,
document_attributes: attributes,
};

return new Document({ pageContent, metadata });
}

// A method to extract the top-k documents from a QueryCommandOutput object.
getQueryDocs(response: QueryCommandOutput, pageSize: number) {
if (!response.ResultItems) return [];
const { length } = response.ResultItems;
const count = length < pageSize ? length : pageSize;
return response.ResultItems.slice(0, count).map((item) =>
this.convertQueryItem(item)
);
}

// A method to send a retrieve or query request to Kendra and return the top-k documents.
async queryKendra(
query: string,
topK: number,
attributeFilter?: AttributeFilter
) {
const retrieveCommand = new RetrieveCommand({
IndexId: this.indexId,
QueryText: query,
PageSize: topK,
AttributeFilter: attributeFilter,
});

const retrieveResponse = await this.kendraClient.send(retrieveCommand);
const retriveLength = retrieveResponse.ResultItems?.length;

if (retriveLength === 0) {
// Retrieve API returned 0 results, call query API
const queryCommand = new QueryCommand({
IndexId: this.indexId,
QueryText: query,
PageSize: topK,
AttributeFilter: attributeFilter,
});

const queryResponse = await this.kendraClient.send(queryCommand);
return this.getQueryDocs(queryResponse, this.topK);
} else {
return this.getRetrieverDocs(retrieveResponse, this.topK);
}
}

async _getRelevantDocuments(query: string): Promise<Document[]> {
const docs = await this.queryKendra(query, this.topK, this.attributeFilter);
return docs;
}
}
Loading

0 comments on commit e02a886

Please sign in to comment.