Skip to content

Commit

Permalink
Add cosine similarity and maximal marginal relevance functions (langc…
Browse files Browse the repository at this point in the history
…hain-ai#1973)

* Added cosine similarity function to math utils

* Added Maximal Marginal Relevance function to math utils

* Changed matrix framework from tensorflow to mljs

Co-authored-by: Ben Perlmutter <[email protected]>

* Added support for 2d array queryEmbedding parameter in MMR function

---------

Co-authored-by: Ben Perlmutter <[email protected]>
  • Loading branch information
archie-swif and mongodben authored Jul 20, 2023
1 parent 48ac23f commit 6603a3e
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 0 deletions.
1 change: 1 addition & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@
"ioredis": "^5.3.2",
"jest": "^29.5.0",
"mammoth": "^1.5.1",
"ml-matrix": "^6.10.4",
"mongodb": "^5.2.0",
"mysql2": "^3.3.3",
"notion-to-md": "^3.1.0",
Expand Down
125 changes: 125 additions & 0 deletions langchain/src/util/math_utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import { similarity as ml_distance_similarity } from "ml-distance";

/**
* This function calculates the row-wise cosine similarity between two matrices with the same number of columns.
*
* @param {number[][]} X - The first matrix.
* @param {number[][]} Y - The second matrix.
*
* @throws {Error} If the number of columns in X and Y are not the same.
*
* @returns {number[][] | [[]]} A matrix where each row represents the cosine similarity values between the corresponding rows of X and Y.
*/
export function cosineSimilarity(X: number[][], Y: number[][]): number[][] {
if (
X.length === 0 ||
X[0].length === 0 ||
Y.length === 0 ||
Y[0].length === 0
) {
return [[]];
}

if (X[0].length !== Y[0].length) {
throw new Error(
`Number of columns in X and Y must be the same. X has shape ${[
X.length,
X[0].length,
]} and Y has shape ${[Y.length, Y[0].length]}.`
);
}

return X.map((xVector) =>
Y.map((yVector) => ml_distance_similarity.cosine(xVector, yVector)).map(
(similarity) => (Number.isNaN(similarity) ? 0 : similarity)
)
);
}

/**
* This function implements the Maximal Marginal Relevance algorithm
* to select a set of embeddings that maximizes the diversity and relevance to a query embedding.
*
* @param {number[]|number[][]} queryEmbedding - The query embedding.
* @param {number[][]} embeddingList - The list of embeddings to select from.
* @param {number} [lambda=0.5] - The trade-off parameter between relevance and diversity.
* @param {number} [k=4] - The maximum number of embeddings to select.
*
* @returns {number[]} The indexes of the selected embeddings in the embeddingList.
*/
export function maximalMarginalRelevance(
queryEmbedding: number[] | number[][],
embeddingList: number[][],
lambda = 0.5,
k = 4
): number[] {
if (Math.min(k, embeddingList.length) <= 0) {
return [];
}

const queryEmbeddingExpanded = (
Array.isArray(queryEmbedding[0]) ? queryEmbedding : [queryEmbedding]
) as number[][];

const similarityToQuery = cosineSimilarity(
queryEmbeddingExpanded,
embeddingList
)[0];
const mostSimilarEmbeddingIndex = argMax(similarityToQuery);

const selectedEmbeddings = [embeddingList[mostSimilarEmbeddingIndex]];
const selectedEmbeddingsIndexes = [mostSimilarEmbeddingIndex];

while (selectedEmbeddingsIndexes.length < Math.min(k, embeddingList.length)) {
let bestScore = -Infinity;
let bestIndex = -1;

const similarityToSelected = cosineSimilarity(
embeddingList,
selectedEmbeddings
);

similarityToQuery.forEach((queryScore, queryScoreIndex) => {
if (queryScoreIndex in selectedEmbeddingsIndexes) {
return;
}
const maxSimilarityToSelected = Math.max(
...similarityToSelected[queryScoreIndex]
);
const score =
lambda * queryScore - (1 - lambda) * maxSimilarityToSelected;

if (score > bestScore) {
bestScore = score;
bestIndex = queryScoreIndex;
}
});
selectedEmbeddings.push(embeddingList[bestIndex]);
selectedEmbeddingsIndexes.push(bestIndex);
}

return selectedEmbeddingsIndexes;
}

/**
* Finds the index of the maximum value in the given array.
* @param {number[]} array - The input array.
*
* @returns {number} The index of the maximum value in the array. If the array is empty, returns -1.
*/
function argMax(array: number[]): number {
if (array.length === 0) {
return -1;
}

let maxValue = array[0];
let maxIndex = 0;

for (let i = 1; i < array.length; i += 1) {
if (array[i] > maxValue) {
maxIndex = i;
maxValue = array[i];
}
}
return maxIndex;
}
131 changes: 131 additions & 0 deletions langchain/src/util/tests/math_utils.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import { test, expect } from "@jest/globals";
import { Matrix } from "ml-matrix";
import { cosineSimilarity, maximalMarginalRelevance } from "../math_utils.js";

test("Test cosine similarity zero", async () => {
const X = Matrix.rand(3, 3).to2DArray();
const Y = Matrix.zeros(3, 3).to2DArray();
const expected = Matrix.zeros(3, 3).to2DArray();
const actual = cosineSimilarity(X, Y);
expect(actual).toEqual(expected);
});

test("Test cosine similarity identity", async () => {
const X = Matrix.rand(4, 4).to2DArray();
const actual = cosineSimilarity(X, X);

// Diagonal is expected to be [1, 1, 1, 1]
for (let i = 0; i < 4; i += 1) {
expect(actual[i][i]).toBeCloseTo(1);
}
});

test("Test cosine similarity", async () => {
const X = [
[1.0, 2.0, 3.0],
[0.0, 1.0, 0.0],
[1.0, 2.0, 0.0],
];

const Y = [
[0.5, 1.0, 1.5],
[1.0, 0.0, 0.0],
[2.0, 5.0, 2.0],
[0.0, 0.0, 0.0],
];

const expected = [
[1, 0.2672612419124244, 0.8374357893586237, 0],
[0.5345224838248488, 0, 0.8703882797784892, 0],
[0.5976143046671968, 0.4472135954999579, 0.9341987329938275, 0],
];

const actual = cosineSimilarity(X, Y);
expect(actual).toEqual(expected);
});

test("Test cosine similarity empty", async () => {
const X = [[]];
const Y = Matrix.rand(3, 3).to2DArray();
expect(cosineSimilarity(X, X)).toEqual([[]]);
expect(cosineSimilarity(X, Y)).toEqual([[]]);
});

test("Test cosine similarity wrong shape", async () => {
const X = Matrix.rand(2, 2).to2DArray();
const Y = Matrix.rand(2, 4).to2DArray();
expect(() => cosineSimilarity(X, Y)).toThrowError();
});

test("Test cosine similarity different shape", async () => {
const X = Matrix.rand(2, 2).to2DArray();
const Y = Matrix.rand(4, 2).to2DArray();
expect(() => cosineSimilarity(X, Y)).not.toThrowError();
});

test("Test maximal marginal relevance lambda zero", async () => {
const queryEmbedding = Matrix.rand(5, 1).to1DArray();
const zeros = Matrix.zeros(5, 1).to1DArray();
const embeddingList = [queryEmbedding, queryEmbedding, zeros];

const expected = [0, 2];
const actual = maximalMarginalRelevance(queryEmbedding, embeddingList, 0, 2);

expect(actual).toEqual(expected);
});

test("Test maximal marginal relevance lambda one", async () => {
const queryEmbedding = Matrix.rand(5, 1).to1DArray();
const zeros = Matrix.zeros(5, 1).to1DArray();
const embeddingList = [queryEmbedding, queryEmbedding, zeros];

const expected = [0, 1];
const actual = maximalMarginalRelevance(queryEmbedding, embeddingList, 1, 2);

expect(actual).toEqual(expected);
});

test("Test maximal marginal relevance", async () => {
// Vectors that are 30, 45 and 75 degrees from query vector (cosine similarity of
// 0.87, 0.71, 0.26) and the latter two are 15 and 60 degree from the first
// (cosine similarity 0.97 and 0.71). So for 3rd vector be chosen, must be case that
// 0.71lambda - 0.97(1 - lambda) < 0.26lambda - 0.71(1-lambda) -> lambda ~< .26 / .71

const queryEmbedding = [1, 0];
const embeddingList = [
[3 ** 0.5, 1],
[1, 1],
[1, 2 + 3 ** 0.5],
];

let expected = [0, 2];
let actual = maximalMarginalRelevance(
queryEmbedding,
embeddingList,
25 / 71,
2
);
expect(actual).toEqual(expected);

expected = [0, 1];
actual = maximalMarginalRelevance(queryEmbedding, embeddingList, 27 / 71, 2);
expect(actual).toEqual(expected);
});

test("Test maximal marginal relevance query dim", async () => {
const randomVector = Matrix.rand(5, 1);

const queryEmbedding = randomVector.to1DArray();
const queryEmbedding2D = randomVector.transpose().to2DArray();
const embeddingList = Matrix.rand(4, 5).to2DArray();

const first = maximalMarginalRelevance(queryEmbedding, embeddingList, 1, 2);
const second = maximalMarginalRelevance(
queryEmbedding2D,
embeddingList,
1,
2
);

expect(first).toEqual(second);
});
40 changes: 40 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13023,6 +13023,7 @@ __metadata:
langsmith: ~0.0.11
mammoth: ^1.5.1
ml-distance: ^4.0.0
ml-matrix: ^6.10.4
mongodb: ^5.2.0
mysql2: ^3.3.3
notion-to-md: ^3.1.0
Expand Down Expand Up @@ -14141,6 +14142,15 @@ __metadata:
languageName: node
linkType: hard

"ml-array-max@npm:^1.2.4":
version: 1.2.4
resolution: "ml-array-max@npm:1.2.4"
dependencies:
is-any-array: ^2.0.0
checksum: af59075eb6bf0076e179a075748f8f1a10f1d60f7ae55103035d5aca637ceb6a109e47bce28bfb82756c977652cbcad4d985e859cacd517edc8807f2e61f7abf
languageName: node
linkType: hard

"ml-array-mean@npm:^1.1.6":
version: 1.1.6
resolution: "ml-array-mean@npm:1.1.6"
Expand All @@ -14150,6 +14160,26 @@ __metadata:
languageName: node
linkType: hard

"ml-array-min@npm:^1.2.3":
version: 1.2.3
resolution: "ml-array-min@npm:1.2.3"
dependencies:
is-any-array: ^2.0.0
checksum: 7a09d5b4cf563a4743b69e5a395f6a617d6fd74ae5f35d0b77ca8ac9568d98b61249bd7d1f962a6e744726ebb94a6ece6e386a6e024ad0e9d329bce7e7e9f2c3
languageName: node
linkType: hard

"ml-array-rescale@npm:^1.3.7":
version: 1.3.7
resolution: "ml-array-rescale@npm:1.3.7"
dependencies:
is-any-array: ^2.0.0
ml-array-max: ^1.2.4
ml-array-min: ^1.2.3
checksum: 7852a09cbc1f39ed625a93ba803ecc13438ddcae20961d7435fb0a89512b66e282b5ea0f425458813028f4004252ed40c6407b893d4b1910591c5aabc8e93810
languageName: node
linkType: hard

"ml-array-sum@npm:^1.1.6":
version: 1.1.6
resolution: "ml-array-sum@npm:1.1.6"
Expand Down Expand Up @@ -14177,6 +14207,16 @@ __metadata:
languageName: node
linkType: hard

"ml-matrix@npm:^6.10.4":
version: 6.10.4
resolution: "ml-matrix@npm:6.10.4"
dependencies:
is-any-array: ^2.0.0
ml-array-rescale: ^1.3.7
checksum: fe9895af746eec3db316451b6db40dcb14fe91f9939640a26202618ae8af826007e1db821e2c914a1555427cdf69a0c70b14965e4464efed0e180db42749b6c0
languageName: node
linkType: hard

"ml-tree-similarity@npm:^1.0.0":
version: 1.0.0
resolution: "ml-tree-similarity@npm:1.0.0"
Expand Down

0 comments on commit 6603a3e

Please sign in to comment.