Skip to content

Commit

Permalink
Nc/tensorflow embeddings (langchain-ai#1020)
Browse files Browse the repository at this point in the history
* Add tensorflow embeddings

* Add entrypoints

* Lint

* Update name
  • Loading branch information
nfcampos authored Apr 28, 2023
1 parent a5a24a1 commit e92bbca
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 2 deletions.
19 changes: 19 additions & 0 deletions docs/docs/modules/models/embeddings/integrations.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const embeddings = new OpenAIEmbeddings({

## `CohereEmbeddings`

The `CohereEmbeddings` class uses the Cohere API to generate embeddings for a given text.

```bash npm2yarn
npm install cohere-ai
```
Expand All @@ -32,3 +34,20 @@ const embeddings = new CohereEmbeddings({
apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY
});
```

## `TensorFlowEmbeddings`

This Embeddings integration computes embeddings entirely in your browser or Node.js environment, using [TensorFlow.js](https://www.tensorflow.org/js). This means that your data isn't sent to any third party, and you don't need to sign up for any API keys. However, it does require more memory and processing power than the other integrations.

```bash npm2yarn
npm install @tensorflow/tfjs-core @tensorflow/tfjs-converter @tensorflow-models/universal-sentence-encoder @tensorflow/tfjs-backend-cpu
```

```typescript
import "@tensorflow/tfjs-backend-cpu";
import { TensorFlowEmbeddings } from "langchain/embeddings/tensorflow";

const embeddings = new TensorFlowEmbeddings();
```

This example uses the CPU backend, which works in any JS environment. However, you can use any of the backends supported by TensorFlow.js, including GPU and WebAssembly backends, which are usually much faster. In Node.js you can use the `@tensorflow/tfjs-node` package, and in the browser you can use the `@tensorflow/tfjs-backend-webgl` package. See the [TensorFlow.js documentation](https://www.tensorflow.org/js/guide/platform_environment) for more information.
3 changes: 3 additions & 0 deletions langchain/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ embeddings/openai.d.ts
embeddings/cohere.cjs
embeddings/cohere.js
embeddings/cohere.d.ts
embeddings/tensorflow.cjs
embeddings/tensorflow.js
embeddings/tensorflow.d.ts
llms.cjs
llms.js
llms.d.ts
Expand Down
24 changes: 24 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
"embeddings/cohere.cjs",
"embeddings/cohere.js",
"embeddings/cohere.d.ts",
"embeddings/tensorflow.cjs",
"embeddings/tensorflow.js",
"embeddings/tensorflow.d.ts",
"llms.cjs",
"llms.js",
"llms.d.ts",
Expand Down Expand Up @@ -294,6 +297,10 @@
"@opensearch-project/opensearch": "^2.2.0",
"@pinecone-database/pinecone": "^0.0.14",
"@supabase/supabase-js": "^2.10.0",
"@tensorflow-models/universal-sentence-encoder": "^1.3.3",
"@tensorflow/tfjs-backend-cpu": "^4.4.0",
"@tensorflow/tfjs-converter": "^4.4.0",
"@tensorflow/tfjs-core": "^4.4.0",
"@tsconfig/recommended": "^1.0.2",
"@types/d3-dsv": "^2",
"@types/flat": "^5.0.2",
Expand Down Expand Up @@ -346,6 +353,9 @@
"@opensearch-project/opensearch": "*",
"@pinecone-database/pinecone": "*",
"@supabase/supabase-js": "^2.10.0",
"@tensorflow-models/universal-sentence-encoder": "*",
"@tensorflow/tfjs-converter": "*",
"@tensorflow/tfjs-core": "*",
"@zilliz/milvus2-sdk-node": "^2.2.0",
"axios": "*",
"cheerio": "^1.0.0-rc.12",
Expand Down Expand Up @@ -388,6 +398,15 @@
"@supabase/supabase-js": {
"optional": true
},
"@tensorflow-models/universal-sentence-encoder": {
"optional": true
},
"@tensorflow/tfjs-converter": {
"optional": true
},
"@tensorflow/tfjs-core": {
"optional": true
},
"@zilliz/milvus2-sdk-node": {
"optional": true
},
Expand Down Expand Up @@ -559,6 +578,11 @@
"import": "./embeddings/cohere.js",
"require": "./embeddings/cohere.cjs"
},
"./embeddings/tensorflow": {
"types": "./embeddings/tensorflow.d.ts",
"import": "./embeddings/tensorflow.js",
"require": "./embeddings/tensorflow.cjs"
},
"./llms": {
"node": {
"types": "./llms.d.ts",
Expand Down
2 changes: 2 additions & 0 deletions langchain/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const entrypoints = {
"embeddings/fake": "embeddings/fake",
"embeddings/openai": "embeddings/openai",
"embeddings/cohere": "embeddings/cohere",
"embeddings/tensorflow": "embeddings/tensorflow",
// llms
llms: "llms/index",
"llms/load": "llms/load",
Expand Down Expand Up @@ -135,6 +136,7 @@ const requiresOptionalDependency = [
"tools/webbrowser",
"chains/load",
"embeddings/cohere",
"embeddings/tensorflow",
"llms/load",
"llms/cohere",
"llms/hf",
Expand Down
44 changes: 44 additions & 0 deletions langchain/src/embeddings/tensorflow.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { load } from "@tensorflow-models/universal-sentence-encoder";
import * as tf from "@tensorflow/tfjs-core";

import { Embeddings, EmbeddingsParams } from "./base.js";

/**
 * Constructor options accepted by `TensorFlowEmbeddings`. Declares no fields
 * of its own; it only inherits those of `EmbeddingsParams`.
 */
export interface TensorFlowEmbeddingsParams extends EmbeddingsParams {}

/**
 * Embeddings implementation backed by the
 * `@tensorflow-models/universal-sentence-encoder` model, executed locally
 * through TensorFlow.js instead of a remote API.
 *
 * A TensorFlow.js backend must already be registered (for example by
 * importing `@tensorflow/tfjs-backend-cpu`) before an instance is created.
 */
export class TensorFlowEmbeddings extends Embeddings {
  /**
   * Promise for the loaded model, created on first use and shared by all
   * subsequent calls. `undefined` until the first load is started, and reset
   * to `undefined` if that load fails.
   */
  _cached?: ReturnType<typeof load>;

  /**
   * @param fields Optional base embeddings parameters, forwarded to
   *   `Embeddings`.
   * @throws Error if `tf.backend()` throws, which is treated as "no
   *   TensorFlow.js backend is available".
   */
  constructor(fields?: TensorFlowEmbeddingsParams) {
    super(fields ?? {});

    try {
      // Called only for its side effect of failing fast when no backend
      // can be resolved; the returned backend itself is not needed here.
      tf.backend();
    } catch (e) {
      throw new Error(
        "No TensorFlow backend found, see instructions at https://www.tensorflow.org/js/guide/platform_environment"
      );
    }
  }

  /**
   * Returns the model, starting the load on first use and reusing the same
   * promise afterwards.
   */
  private async load() {
    let model = this._cached;
    if (model === undefined) {
      const pending = load();
      model = pending;
      this._cached = pending;
      // Do not keep a rejected promise around: otherwise a single failed
      // load would make every later call on this instance fail as well.
      pending.catch(() => {
        if (this._cached === pending) {
          this._cached = undefined;
        }
      });
    }
    return model;
  }

  /**
   * Embeds `texts` and returns the result as plain nested arrays. The work is
   * routed through `this.caller`, inherited from `Embeddings`.
   */
  private _embed(texts: string[]): Promise<number[][]> {
    return this.caller.call(async () => {
      const model = await this.load();
      const tensor = await model.embed(texts);
      try {
        return await tensor.array();
      } finally {
        // Tensors are not garbage collected by TensorFlow.js; release the
        // backing memory once the values have been copied out.
        tensor.dispose();
      }
    });
  }

  /**
   * Embeds a single piece of text.
   *
   * @param document Text to embed.
   * @returns The embedding vector for `document`.
   */
  async embedQuery(document: string): Promise<number[]> {
    const embeddings = await this._embed([document]);
    return embeddings[0];
  }

  /**
   * Embeds a batch of texts.
   *
   * @param documents Texts to embed.
   * @returns One embedding vector per input, in the same order. An empty
   *   input yields an empty array without loading the model.
   */
  async embedDocuments(documents: string[]): Promise<number[][]> {
    if (documents.length === 0) {
      return [];
    }
    return this._embed(documents);
  }
}
41 changes: 41 additions & 0 deletions langchain/src/embeddings/tests/tensorflow.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import { test, expect } from "@jest/globals";
import "@tensorflow/tfjs-backend-cpu";
import { TensorFlowEmbeddings } from "../tensorflow.js";
import { MemoryVectorStore } from "../../vectorstores/memory.js";
import { Document } from "../../document.js";

// Integration test for TensorFlowEmbeddings: checks the shape of a query
// embedding, then uses the embeddings in a MemoryVectorStore similarity search.
test("TensorflowEmbeddings", async () => {
const embeddings = new TensorFlowEmbeddings();

// Sample corpus; entries 4 and 5 are the two arithmetic strings that the
// similarity search below is expected to return.
const documents = [
"Hello world!",
"Hello bad world!",
"Hello nice world!",
"Hello good world!",
"1 + 1 = 2",
"1 + 1 = 3",
];

// A single query embedding is expected to be a 512-element array of numbers.
const queryEmbedding = await embeddings.embedQuery(documents[0]);
expect(queryEmbedding).toHaveLength(512);
expect(typeof queryEmbedding[0]).toBe("number");

const store = new MemoryVectorStore(embeddings);

// Index every sample string as a Document with no extra metadata.
await store.addDocuments(
documents.map((pageContent) => new Document({ pageContent }))
);

// Querying with "1 + 1 = 2" and k = 2 should return that document first,
// followed by "1 + 1 = 3".
expect(await store.similaritySearch(documents[4], 2)).toMatchInlineSnapshot(`
[
Document {
"metadata": {},
"pageContent": "1 + 1 = 2",
},
Document {
"metadata": {},
"pageContent": "1 + 1 = 3",
},
]
`);
});
1 change: 1 addition & 0 deletions langchain/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"src/embeddings/fake.ts",
"src/embeddings/openai.ts",
"src/embeddings/cohere.ts",
"src/embeddings/tensorflow.ts",
"src/llms/load.ts",
"src/llms/base.ts",
"src/llms/openai.ts",
Expand Down
99 changes: 97 additions & 2 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -6849,6 +6849,53 @@ __metadata:
languageName: node
linkType: hard

"@tensorflow-models/universal-sentence-encoder@npm:^1.3.3":
version: 1.3.3
resolution: "@tensorflow-models/universal-sentence-encoder@npm:1.3.3"
peerDependencies:
"@tensorflow/tfjs-converter": ^3.6.0
"@tensorflow/tfjs-core": ^3.6.0
checksum: 99fea126088a11385501f104983fbc3de2a3cb4414aad94d08200cd40678f7f7711c7e8735ad570369151834b6ec8acb970d7b464c76d40ff9f102e476e45f57
languageName: node
linkType: hard

"@tensorflow/tfjs-backend-cpu@npm:^4.4.0":
version: 4.4.0
resolution: "@tensorflow/tfjs-backend-cpu@npm:4.4.0"
dependencies:
"@types/seedrandom": ^2.4.28
seedrandom: ^3.0.5
peerDependencies:
"@tensorflow/tfjs-core": 4.4.0
checksum: e84eee75989ea7d2fcf7090a59021a44fbc9428cf4264cbaaaa0bba16664728a4edf5f229d5ee9c5e860c62b00d052b9e41204ca5b5b1a24cebf614f43deebe8
languageName: node
linkType: hard

"@tensorflow/tfjs-converter@npm:^4.4.0":
version: 4.4.0
resolution: "@tensorflow/tfjs-converter@npm:4.4.0"
peerDependencies:
"@tensorflow/tfjs-core": 4.4.0
checksum: 5e0c71f53ed9ee48122f6580ae35ef3b82ed84caeab3ee5f6ea24d88f9761c01649d9fe5d2603558fa9c133bf0cc55e60b17db2e62ea1a38349e123418852fdf
languageName: node
linkType: hard

"@tensorflow/tfjs-core@npm:^4.4.0":
version: 4.4.0
resolution: "@tensorflow/tfjs-core@npm:4.4.0"
dependencies:
"@types/long": ^4.0.1
"@types/offscreencanvas": ~2019.7.0
"@types/seedrandom": ^2.4.28
"@types/webgl-ext": 0.0.30
"@webgpu/types": 0.1.30
long: 4.0.0
node-fetch: ~2.6.1
seedrandom: ^3.0.5
checksum: 89b1c7a726cf0d640bc0dcc7404fc29b3b8ee229de52ab0c58f665b9933958f9ec140684e4ceb868860cdf97b9e4fa70c5df679a81bfae7c77f3783ba71c02c2
languageName: node
linkType: hard

"@testing-library/dom@npm:^8.5.0":
version: 8.20.0
resolution: "@testing-library/dom@npm:8.20.0"
Expand Down Expand Up @@ -7352,6 +7399,13 @@ __metadata:
languageName: node
linkType: hard

"@types/offscreencanvas@npm:~2019.7.0":
version: 2019.7.0
resolution: "@types/offscreencanvas@npm:2019.7.0"
checksum: 018cfcd19e0c59c44d14ba61caaca7246f77fbb512839c7881654b7f2b6591dbdd5857362eccbf49f29cdc93724e71a4b37c8b6cf203388f9c04e913a53ea390
languageName: node
linkType: hard

"@types/parse-json@npm:^4.0.0":
version: 4.0.0
resolution: "@types/parse-json@npm:4.0.0"
Expand Down Expand Up @@ -7526,6 +7580,13 @@ __metadata:
languageName: node
linkType: hard

"@types/seedrandom@npm:^2.4.28":
version: 2.4.30
resolution: "@types/seedrandom@npm:2.4.30"
checksum: 1bcf634bb0146b5de443d4581556795e52186f830072ef12f42efb6f15a4e9950630f1c4322356df0f39e953436d4fb17b5fb2834e867ff6cf6e5570849ebad2
languageName: node
linkType: hard

"@types/semver@npm:^7.3.12":
version: 7.3.13
resolution: "@types/semver@npm:7.3.13"
Expand Down Expand Up @@ -7612,6 +7673,13 @@ __metadata:
languageName: node
linkType: hard

"@types/webgl-ext@npm:0.0.30":
version: 0.0.30
resolution: "@types/webgl-ext@npm:0.0.30"
checksum: c98aa8af2d24d54dc29d836aed0ca5acff45903dc91879d62110e84e615024ca071af69d1dcff26857aa2ab80c86bea9373d718616dcc61dbfdc6a6dd917a064
languageName: node
linkType: hard

"@types/webidl-conversions@npm:*":
version: 7.0.0
resolution: "@types/webidl-conversions@npm:7.0.0"
Expand Down Expand Up @@ -8119,6 +8187,13 @@ __metadata:
languageName: node
linkType: hard

"@webgpu/types@npm:0.1.30":
version: 0.1.30
resolution: "@webgpu/types@npm:0.1.30"
checksum: c07516879c60617214717d63789f4a512004d004e845ed0439449f564fb8b94bd48882d599918e7db26914d584c70aecda25e90a327aabc32e57c90aa818ac06
languageName: node
linkType: hard

"@xtuc/ieee754@npm:^1.2.0":
version: 1.2.0
resolution: "@xtuc/ieee754@npm:1.2.0"
Expand Down Expand Up @@ -17117,6 +17192,10 @@ __metadata:
"@opensearch-project/opensearch": ^2.2.0
"@pinecone-database/pinecone": ^0.0.14
"@supabase/supabase-js": ^2.10.0
"@tensorflow-models/universal-sentence-encoder": ^1.3.3
"@tensorflow/tfjs-backend-cpu": ^4.4.0
"@tensorflow/tfjs-converter": ^4.4.0
"@tensorflow/tfjs-core": ^4.4.0
"@tsconfig/recommended": ^1.0.2
"@types/d3-dsv": ^2
"@types/flat": ^5.0.2
Expand Down Expand Up @@ -17182,6 +17261,9 @@ __metadata:
"@opensearch-project/opensearch": "*"
"@pinecone-database/pinecone": "*"
"@supabase/supabase-js": ^2.10.0
"@tensorflow-models/universal-sentence-encoder": "*"
"@tensorflow/tfjs-converter": "*"
"@tensorflow/tfjs-core": "*"
"@zilliz/milvus2-sdk-node": ^2.2.0
axios: "*"
cheerio: ^1.0.0-rc.12
Expand Down Expand Up @@ -17216,6 +17298,12 @@ __metadata:
optional: true
"@supabase/supabase-js":
optional: true
"@tensorflow-models/universal-sentence-encoder":
optional: true
"@tensorflow/tfjs-converter":
optional: true
"@tensorflow/tfjs-core":
optional: true
"@zilliz/milvus2-sdk-node":
optional: true
axios:
Expand Down Expand Up @@ -17607,7 +17695,7 @@ __metadata:
languageName: node
linkType: hard

"long@npm:^4.0.0":
"long@npm:4.0.0, long@npm:^4.0.0":
version: 4.0.0
resolution: "long@npm:4.0.0"
checksum: 16afbe8f749c7c849db1f4de4e2e6a31ac6e617cead3bdc4f9605cb703cd20e1e9fc1a7baba674ffcca57d660a6e5b53a9e236d7b25a295d3855cca79cc06744
Expand Down Expand Up @@ -18679,7 +18767,7 @@ __metadata:
languageName: node
linkType: hard

"node-fetch@npm:^2.6.1, node-fetch@npm:^2.6.7":
"node-fetch@npm:^2.6.1, node-fetch@npm:^2.6.7, node-fetch@npm:~2.6.1":
version: 2.6.9
resolution: "node-fetch@npm:2.6.9"
dependencies:
Expand Down Expand Up @@ -22468,6 +22556,13 @@ __metadata:
languageName: node
linkType: hard

"seedrandom@npm:^3.0.5":
version: 3.0.5
resolution: "seedrandom@npm:3.0.5"
checksum: 728b56bc3bc1b9ddeabd381e449b51cb31bdc0aa86e27fcd0190cea8c44613d5bcb2f6bb63ed79f78180cbe791c20b8ec31a9627f7b7fc7f476fd2bdb7e2da9f
languageName: node
linkType: hard

"selderee@npm:^0.11.0":
version: 0.11.0
resolution: "selderee@npm:0.11.0"
Expand Down

0 comments on commit e92bbca

Please sign in to comment.