Skip to content

Commit

Permalink
Improve error handling and embedding prompts; fix typos (gravitationa…
Browse files Browse the repository at this point in the history
…l#28403)

* "Improve error handling and embedding prompts; fix typos"

This commit encompasses several changes. First, an error handling routine has been added in AssistContext.tsx to properly close a WebSocket connection and finish all results. The intent is to ensure that execution fails gracefully when a session doesn't end normally. In tool.go, user instructions have been made more explicit to ensure users check access to nodes before generating any commands. It warns them that not checking access will cause error. Also, some minor typos were corrected in agent.go and messages.go for better readability.

* "Refactor 'hosts' to 'nodes' in AI Tool Descriptions"

This commit refactors the language from 'host' terminology to 'node' terminology in the AI tool's generated responses as the LLM seems to be confused when generating queries with embeddings.

* Update expected test values in chat_test.go

The expected values in three different tests in chat_test.go have been updated. This change was required because the underlying algorithm has been adjusted and these modifications will keep the tests aligned with the current algorithm's behavior.
  • Loading branch information
jakule authored Jun 29, 2023
1 parent ebecb59 commit c9bf80b
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 10 deletions.
6 changes: 3 additions & 3 deletions lib/ai/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestChat_PromptTokens(t *testing.T) {
Content: "Hello",
},
},
want: 703,
want: 743,
},
{
name: "system and user messages",
Expand All @@ -63,7 +63,7 @@ func TestChat_PromptTokens(t *testing.T) {
Content: "Hi LLM.",
},
},
want: 711,
want: 751,
},
{
name: "tokenize our prompt",
Expand All @@ -77,7 +77,7 @@ func TestChat_PromptTokens(t *testing.T) {
Content: "Show me free disk space on localhost node.",
},
},
want: 914,
want: 954,
},
}

Expand Down
2 changes: 1 addition & 1 deletion lib/ai/model/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type Agent struct {
tools []Tool
}

// agentAction is an event type represetning the decision to take a single action, typically a tool invocation.
// agentAction is an event type representing the decision to take a single action, typically a tool invocation.
type agentAction struct {
// The action to take, typically a tool name.
action string
Expand Down
2 changes: 1 addition & 1 deletion lib/ai/model/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ type TokensUsed struct {
}

// UsedTokens returns the number of tokens used during a single invocation of the agent.
// This method creates a convinient way to get TokensUsed from embedded structs.
// This method creates a convenient way to get TokensUsed from embedded structs.
func (t *TokensUsed) UsedTokens() *TokensUsed {
return t
}
Expand Down
12 changes: 7 additions & 5 deletions lib/ai/model/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ func (c *commandExecutionTool) Name() string {
}

func (c *commandExecutionTool) Description() string {
return fmt.Sprintf(`Execute a command on a set of remote hosts based on a set of hostnames or/and a set of labels.
return fmt.Sprintf(`Execute a command on a set of remote nodes based on a set of node names or/and a set of labels.
The input must be a JSON object with the following schema:
%vjson
{
"command": string, \\ The command to execute
"nodes": []string, \\ Execute a command on all nodes that have the given hostnames
"nodes": []string, \\ Execute a command on all nodes that have the given node names
"labels": []{"key": string, "value": string} \\ Execute a command on all nodes that has at least one of the labels
}
%v
Expand Down Expand Up @@ -140,7 +140,7 @@ func (e *embeddingRetrievalTool) Run(ctx context.Context, input string) (string,
if sb.Len() == 0 {
// Either no nodes are connected, embedding process hasn't started yet, or
// the user doesn't have access to any resources.
return "Didn't find any nodes matching your query", nil
return "Didn't find any nodes matching the query", nil
}

return sb.String(), nil
Expand All @@ -151,11 +151,13 @@ func (e *embeddingRetrievalTool) Name() string {
}

func (e *embeddingRetrievalTool) Description() string {
return fmt.Sprintf(`Ask about existing remote hosts to fetch node names or/and set of labels. Use this capability instead of guessing the names and labels.
return fmt.Sprintf(`Ask about existing remote nodes that user has access to fetch node names or/and set of labels.
Always use this capability before returning generating any command. Do not assume that the user has access to any nodes. Returning a command without checking for access will result in an error.
Always prefer to use labler rather than node names.
The input must be a JSON object with the following schema:
%vjson
{
"question": string \\ Question about the available remote hosts
"question": string \\ Question about the available remote nodes
}
%v
`, "```", "```")
Expand Down
15 changes: 15 additions & 0 deletions web/packages/teleport/src/Assist/context/AssistContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,21 @@ export function AssistContextProvider(props: PropsWithChildren<unknown>) {
break;
}
};

executeCommandWebSocket.current.onclose = () => {
executeCommandWebSocket.current = null;

// If the execution failed, we won't get a SESSION_END message, so we
// need to mark all the results as finished here.
for (const nodeId of nodeIdToResultId.keys()) {
dispatch({
type: AssistStateActionType.FinishCommandResult,
conversationId: state.conversations.selectedId,
commandResultId: nodeIdToResultId.get(nodeId),
});
}
nodeIdToResultId.clear();
};
}

async function deleteConversation(conversationId: string) {
Expand Down

0 comments on commit c9bf80b

Please sign in to comment.