Skip to content

Commit

Permalink
Fix node equality check in embedding processor (gravitational#30325)
Browse files Browse the repository at this point in the history
* Fix node equality check in embedding processor

* Apply minor suggestions from code review

* Make process() private, test using only public API

* Remove an unused function

* Update lib/ai/embeddingprocessor.go

Co-authored-by: Alan Parra <[email protected]>

---------

Co-authored-by: Alan Parra <[email protected]>
  • Loading branch information
justinas and codingllama authored Aug 16, 2023
1 parent 3626f26 commit 18016e5
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 32 deletions.
7 changes: 2 additions & 5 deletions lib/ai/embeddingprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ func (e *EmbeddingProcessor) Run(ctx context.Context, initialDelay, period time.
}
}

// process updates embeddings for all nodes once.
func (e *EmbeddingProcessor) process(ctx context.Context) {
batch := NewBatchReducer(e.mapProcessFn,
maxEmbeddingAPISize, // Max batch size allowed by OpenAI API,
Expand Down Expand Up @@ -260,11 +261,7 @@ func (e *EmbeddingProcessor) process(ctx context.Context) {
},
// On compare keys callback. Compare the keys for iteration.
func(node types.Server, embeddings *embeddinglib.Embedding) int {
if node.GetName() == embeddings.GetName() {
return 0
}

return strings.Compare(node.GetName(), embeddings.GetName())
return strings.Compare(node.GetName(), embeddings.GetEmbeddedID())
},
)

Expand Down
92 changes: 65 additions & 27 deletions lib/ai/embeddingprocessor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/sha256"
"errors"
"fmt"
"strings"
"sync"
"testing"
"time"
Expand All @@ -43,11 +44,15 @@ import (
// MockEmbedder returns embeddings derived from the SHA-256 hash of the input.
// These embeddings have no semantic meaning, but different embedded content
// yields different embeddings.
type MockEmbedder struct{}
type MockEmbedder struct {
timesCalled map[string]int
}

func (m MockEmbedder) ComputeEmbeddings(_ context.Context, input []string) ([]embedding.Vector64, error) {
func (m *MockEmbedder) ComputeEmbeddings(_ context.Context, input []string) ([]embedding.Vector64, error) {
result := make([]embedding.Vector64, len(input))
for i, text := range input {
name := strings.Split(text, "\n")[0]
m.timesCalled[name]++
hash := sha256.Sum256([]byte(text))
vector := make(embedding.Vector64, len(hash))
for j, x := range hash {
Expand All @@ -66,6 +71,15 @@ type mockNodeStreamer struct {
// UpsertNode stores node in the mock's node list while holding the mutex.
// An existing entry with the same name is replaced in place; otherwise node
// is appended. The context is ignored and the result is always (nil, nil).
func (m *mockNodeStreamer) UpsertNode(_ context.Context, node types.Server) (*types.KeepAlive, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	for idx := range m.nodes {
		if m.nodes[idx].GetName() != node.GetName() {
			continue
		}
		// A node with this name is already stored: overwrite it.
		m.nodes[idx] = node
		return nil, nil
	}

	// No stored node has this name yet: add it as a new entry.
	m.nodes = append(m.nodes, node)
	return nil, nil
}
Expand Down Expand Up @@ -94,7 +108,9 @@ func TestNodeEmbeddingGeneration(t *testing.T) {
})
require.NoError(t, err)

embedder := MockEmbedder{}
embedder := MockEmbedder{
timesCalled: make(map[string]int),
}
presence := &mockNodeStreamer{}
embeddings := local.NewEmbeddingsService(bk)

Expand All @@ -107,25 +123,16 @@ func TestNodeEmbeddingGeneration(t *testing.T) {
Jitter: retryutils.NewSeventhJitter(),
})

done := make(chan struct{})
go func() {
err := processor.Run(ctx, 100*time.Millisecond, time.Second)
assert.ErrorIs(t, context.Canceled, err)
close(done)
}()

// Add some node servers.
const numNodes = 5
nodes := make([]types.Server, 0, numNodes)
for i := 0; i < numNodes; i++ {
node, _ := types.NewServer(fmt.Sprintf("node%d", i), types.KindNode, types.ServerSpecV2{
Addr: "127.0.0.1:1234",
Hostname: fmt.Sprintf("node%d", i),
CmdLabels: map[string]types.CommandLabelV2{
"version": {Result: "v8"},
"hostname": {Result: fmt.Sprintf("node%d.example.com", i)},
},
})
const numInitialNodes = 5
nodes := make([]types.Server, 0, numInitialNodes)
for i := 0; i < numInitialNodes; i++ {
node := makeNode(i + 1)
_, err = presence.UpsertNode(ctx, node)
require.NoError(t, err)
nodes = append(nodes, node)
Expand All @@ -134,12 +141,41 @@ func TestNodeEmbeddingGeneration(t *testing.T) {
require.Eventually(t, func() bool {
items, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode))
assert.NoError(t, err)
return (len(items) == numNodes) && (len(nodes) == numNodes)
return len(items) == numInitialNodes
}, 7*time.Second, 200*time.Millisecond)

cancel()
validateEmbeddings(t,
presence.GetNodeStream(ctx, defaults.Namespace),
embeddings.GetEmbeddings(ctx, types.KindNode))

waitForDone(t, done, "timed out waiting for processor to stop")
for k, v := range embedder.timesCalled {
require.Equal(t, 1, v, "expected %v to be computed once, was %d", k, v)
}

// Run once more and verify that only changed or newly inserted nodes get their embeddings calculated
node1 := nodes[0]
node1.GetMetadata().Labels["foo"] = "bar"
_, err = presence.UpsertNode(ctx, node1)
require.NoError(t, err)
node6 := makeNode(6)
_, err = presence.UpsertNode(ctx, node6)
require.NoError(t, err)

// Since nodes are streamed in ascending order by names, when embeddings for node6 are calculated,
// we can be sure that our recent changes have been fully processed
require.Eventually(t, func() bool {
items, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode))
assert.NoError(t, err)
return len(items) == numInitialNodes+1
}, 7*time.Second, 200*time.Millisecond)

for k, v := range embedder.timesCalled {
expected := 1
if strings.Contains(k, "node1") {
expected = 2
}
require.Equal(t, expected, v, "expected embedding for %q to be computed %d times, got computed %d times", k, expected, v)
}

validateEmbeddings(t,
presence.GetNodeStream(ctx, defaults.Namespace),
Expand All @@ -162,14 +198,16 @@ func TestMarshallUnmarshallEmbedding(t *testing.T) {
require.Equal(t, initial.Vector, final.Vector)
}

func waitForDone(t *testing.T, done chan struct{}, errMsg string) {
t.Helper()

select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal(errMsg)
}
// makeNode builds a node server named "node<num>" for use in tests. The
// hostname equals the name, and the "hostname" command label is
// "node<num>.example.com". The error returned by types.NewServer is
// discarded.
func makeNode(num int) types.Server {
	name := fmt.Sprintf("node%d", num)
	spec := types.ServerSpecV2{
		Addr:     "127.0.0.1:1234",
		Hostname: name,
		CmdLabels: map[string]types.CommandLabelV2{
			"version":  {Result: "v8"},
			"hostname": {Result: name + ".example.com"},
		},
	}
	node, _ := types.NewServer(name, types.KindNode, spec)
	return node
}

func validateEmbeddings(t *testing.T, nodesStream stream.Stream[types.Server], embeddingsStream stream.Stream[*embedding.Embedding]) {
Expand Down

0 comments on commit 18016e5

Please sign in to comment.