forked from langchain-ai/langchainjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add cosine similarity and maximal marginal relevance functions (langc…
…hain-ai#1973) * Added cosine similarity function to math utils * Added Maximal Marginal Relevance function to math utils * Changed matrix framework from tensorflow to mljs Co-authored-by: Ben Perlmutter <[email protected]> * Added support for 2d array queryEmbedding parameter in MMR function --------- Co-authored-by: Ben Perlmutter <[email protected]>
- Loading branch information
1 parent
48ac23f
commit 6603a3e
Showing
4 changed files
with
297 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import { similarity as ml_distance_similarity } from "ml-distance"; | ||
|
||
/** | ||
* This function calculates the row-wise cosine similarity between two matrices with the same number of columns. | ||
* | ||
* @param {number[][]} X - The first matrix. | ||
* @param {number[][]} Y - The second matrix. | ||
* | ||
* @throws {Error} If the number of columns in X and Y are not the same. | ||
* | ||
* @returns {number[][] | [[]]} A matrix where each row represents the cosine similarity values between the corresponding rows of X and Y. | ||
*/ | ||
export function cosineSimilarity(X: number[][], Y: number[][]): number[][] { | ||
if ( | ||
X.length === 0 || | ||
X[0].length === 0 || | ||
Y.length === 0 || | ||
Y[0].length === 0 | ||
) { | ||
return [[]]; | ||
} | ||
|
||
if (X[0].length !== Y[0].length) { | ||
throw new Error( | ||
`Number of columns in X and Y must be the same. X has shape ${[ | ||
X.length, | ||
X[0].length, | ||
]} and Y has shape ${[Y.length, Y[0].length]}.` | ||
); | ||
} | ||
|
||
return X.map((xVector) => | ||
Y.map((yVector) => ml_distance_similarity.cosine(xVector, yVector)).map( | ||
(similarity) => (Number.isNaN(similarity) ? 0 : similarity) | ||
) | ||
); | ||
} | ||
|
||
/** | ||
* This function implements the Maximal Marginal Relevance algorithm | ||
* to select a set of embeddings that maximizes the diversity and relevance to a query embedding. | ||
* | ||
* @param {number[]|number[][]} queryEmbedding - The query embedding. | ||
* @param {number[][]} embeddingList - The list of embeddings to select from. | ||
* @param {number} [lambda=0.5] - The trade-off parameter between relevance and diversity. | ||
* @param {number} [k=4] - The maximum number of embeddings to select. | ||
* | ||
* @returns {number[]} The indexes of the selected embeddings in the embeddingList. | ||
*/ | ||
export function maximalMarginalRelevance( | ||
queryEmbedding: number[] | number[][], | ||
embeddingList: number[][], | ||
lambda = 0.5, | ||
k = 4 | ||
): number[] { | ||
if (Math.min(k, embeddingList.length) <= 0) { | ||
return []; | ||
} | ||
|
||
const queryEmbeddingExpanded = ( | ||
Array.isArray(queryEmbedding[0]) ? queryEmbedding : [queryEmbedding] | ||
) as number[][]; | ||
|
||
const similarityToQuery = cosineSimilarity( | ||
queryEmbeddingExpanded, | ||
embeddingList | ||
)[0]; | ||
const mostSimilarEmbeddingIndex = argMax(similarityToQuery); | ||
|
||
const selectedEmbeddings = [embeddingList[mostSimilarEmbeddingIndex]]; | ||
const selectedEmbeddingsIndexes = [mostSimilarEmbeddingIndex]; | ||
|
||
while (selectedEmbeddingsIndexes.length < Math.min(k, embeddingList.length)) { | ||
let bestScore = -Infinity; | ||
let bestIndex = -1; | ||
|
||
const similarityToSelected = cosineSimilarity( | ||
embeddingList, | ||
selectedEmbeddings | ||
); | ||
|
||
similarityToQuery.forEach((queryScore, queryScoreIndex) => { | ||
if (queryScoreIndex in selectedEmbeddingsIndexes) { | ||
return; | ||
} | ||
const maxSimilarityToSelected = Math.max( | ||
...similarityToSelected[queryScoreIndex] | ||
); | ||
const score = | ||
lambda * queryScore - (1 - lambda) * maxSimilarityToSelected; | ||
|
||
if (score > bestScore) { | ||
bestScore = score; | ||
bestIndex = queryScoreIndex; | ||
} | ||
}); | ||
selectedEmbeddings.push(embeddingList[bestIndex]); | ||
selectedEmbeddingsIndexes.push(bestIndex); | ||
} | ||
|
||
return selectedEmbeddingsIndexes; | ||
} | ||
|
||
/** | ||
* Finds the index of the maximum value in the given array. | ||
* @param {number[]} array - The input array. | ||
* | ||
* @returns {number} The index of the maximum value in the array. If the array is empty, returns -1. | ||
*/ | ||
function argMax(array: number[]): number { | ||
if (array.length === 0) { | ||
return -1; | ||
} | ||
|
||
let maxValue = array[0]; | ||
let maxIndex = 0; | ||
|
||
for (let i = 1; i < array.length; i += 1) { | ||
if (array[i] > maxValue) { | ||
maxIndex = i; | ||
maxValue = array[i]; | ||
} | ||
} | ||
return maxIndex; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import { test, expect } from "@jest/globals"; | ||
import { Matrix } from "ml-matrix"; | ||
import { cosineSimilarity, maximalMarginalRelevance } from "../math_utils.js"; | ||
|
||
test("Test cosine similarity zero", async () => { | ||
const X = Matrix.rand(3, 3).to2DArray(); | ||
const Y = Matrix.zeros(3, 3).to2DArray(); | ||
const expected = Matrix.zeros(3, 3).to2DArray(); | ||
const actual = cosineSimilarity(X, Y); | ||
expect(actual).toEqual(expected); | ||
}); | ||
|
||
test("Test cosine similarity identity", async () => { | ||
const X = Matrix.rand(4, 4).to2DArray(); | ||
const actual = cosineSimilarity(X, X); | ||
|
||
// Diagonal is expected to be [1, 1, 1, 1] | ||
for (let i = 0; i < 4; i += 1) { | ||
expect(actual[i][i]).toBeCloseTo(1); | ||
} | ||
}); | ||
|
||
test("Test cosine similarity", async () => { | ||
const X = [ | ||
[1.0, 2.0, 3.0], | ||
[0.0, 1.0, 0.0], | ||
[1.0, 2.0, 0.0], | ||
]; | ||
|
||
const Y = [ | ||
[0.5, 1.0, 1.5], | ||
[1.0, 0.0, 0.0], | ||
[2.0, 5.0, 2.0], | ||
[0.0, 0.0, 0.0], | ||
]; | ||
|
||
const expected = [ | ||
[1, 0.2672612419124244, 0.8374357893586237, 0], | ||
[0.5345224838248488, 0, 0.8703882797784892, 0], | ||
[0.5976143046671968, 0.4472135954999579, 0.9341987329938275, 0], | ||
]; | ||
|
||
const actual = cosineSimilarity(X, Y); | ||
expect(actual).toEqual(expected); | ||
}); | ||
|
||
test("Test cosine similarity empty", async () => { | ||
const X = [[]]; | ||
const Y = Matrix.rand(3, 3).to2DArray(); | ||
expect(cosineSimilarity(X, X)).toEqual([[]]); | ||
expect(cosineSimilarity(X, Y)).toEqual([[]]); | ||
}); | ||
|
||
test("Test cosine similarity wrong shape", async () => { | ||
const X = Matrix.rand(2, 2).to2DArray(); | ||
const Y = Matrix.rand(2, 4).to2DArray(); | ||
expect(() => cosineSimilarity(X, Y)).toThrowError(); | ||
}); | ||
|
||
test("Test cosine similarity different shape", async () => { | ||
const X = Matrix.rand(2, 2).to2DArray(); | ||
const Y = Matrix.rand(4, 2).to2DArray(); | ||
expect(() => cosineSimilarity(X, Y)).not.toThrowError(); | ||
}); | ||
|
||
test("Test maximal marginal relevance lambda zero", async () => { | ||
const queryEmbedding = Matrix.rand(5, 1).to1DArray(); | ||
const zeros = Matrix.zeros(5, 1).to1DArray(); | ||
const embeddingList = [queryEmbedding, queryEmbedding, zeros]; | ||
|
||
const expected = [0, 2]; | ||
const actual = maximalMarginalRelevance(queryEmbedding, embeddingList, 0, 2); | ||
|
||
expect(actual).toEqual(expected); | ||
}); | ||
|
||
test("Test maximal marginal relevance lambda one", async () => { | ||
const queryEmbedding = Matrix.rand(5, 1).to1DArray(); | ||
const zeros = Matrix.zeros(5, 1).to1DArray(); | ||
const embeddingList = [queryEmbedding, queryEmbedding, zeros]; | ||
|
||
const expected = [0, 1]; | ||
const actual = maximalMarginalRelevance(queryEmbedding, embeddingList, 1, 2); | ||
|
||
expect(actual).toEqual(expected); | ||
}); | ||
|
||
test("Test maximal marginal relevance", async () => { | ||
// Vectors that are 30, 45 and 75 degrees from query vector (cosine similarity of | ||
// 0.87, 0.71, 0.26) and the latter two are 15 and 60 degree from the first | ||
// (cosine similarity 0.97 and 0.71). So for 3rd vector be chosen, must be case that | ||
// 0.71lambda - 0.97(1 - lambda) < 0.26lambda - 0.71(1-lambda) -> lambda ~< .26 / .71 | ||
|
||
const queryEmbedding = [1, 0]; | ||
const embeddingList = [ | ||
[3 ** 0.5, 1], | ||
[1, 1], | ||
[1, 2 + 3 ** 0.5], | ||
]; | ||
|
||
let expected = [0, 2]; | ||
let actual = maximalMarginalRelevance( | ||
queryEmbedding, | ||
embeddingList, | ||
25 / 71, | ||
2 | ||
); | ||
expect(actual).toEqual(expected); | ||
|
||
expected = [0, 1]; | ||
actual = maximalMarginalRelevance(queryEmbedding, embeddingList, 27 / 71, 2); | ||
expect(actual).toEqual(expected); | ||
}); | ||
|
||
test("Test maximal marginal relevance query dim", async () => { | ||
const randomVector = Matrix.rand(5, 1); | ||
|
||
const queryEmbedding = randomVector.to1DArray(); | ||
const queryEmbedding2D = randomVector.transpose().to2DArray(); | ||
const embeddingList = Matrix.rand(4, 5).to2DArray(); | ||
|
||
const first = maximalMarginalRelevance(queryEmbedding, embeddingList, 1, 2); | ||
const second = maximalMarginalRelevance( | ||
queryEmbedding2D, | ||
embeddingList, | ||
1, | ||
2 | ||
); | ||
|
||
expect(first).toEqual(second); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters