Skip to content

Commit

Permalink
[Assist] Add in SSH context Assist endpoints (gravitational#30319)
Browse files Browse the repository at this point in the history
* Add in SSH context Assist endpoints

This change introduces web endpoints used by Assist to support in SSH context functionality in the WebUI. Now a user will be able to generate Bash command and explain output in the terminal (like application logs).

* Remove mockup web UI

* Make the linter happy

* Remove unnecessary comments and add missing documentation

This commit removes a redundant comment on the AI model selection in client.go that suggested a model change that is not needed. It also adds missing documentation in chat.go, to clarify the purpose and functionality of the Reply function.

* Revert ws.WriteControl to use real time.

Usage of the fake clock was failing the tests, and it's not really beneficial in this case.
  • Loading branch information
jakule authored Aug 14, 2023
1 parent 7835175 commit e572328
Show file tree
Hide file tree
Showing 9 changed files with 535 additions and 113 deletions.
5 changes: 5 additions & 0 deletions lib/ai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ func (chat *Chat) Complete(ctx context.Context, userInput string, progressUpdate
}, model.NewTokenCount(), nil
}

return chat.Reply(ctx, userInput, progressUpdates)
}

// Reply replies to the user input with a message from the assistant based on the current context.
func (chat *Chat) Reply(ctx context.Context, userInput string, progressUpdates func(*model.AgentAction)) (any, *model.TokenCount, error) {
userMessage := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: userInput,
Expand Down
31 changes: 29 additions & 2 deletions lib/ai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,34 @@ func NewClientFromConfig(config openai.ClientConfig) *Client {
// toolsConfig contains all required clients and configuration for agent tools
// to interact with Teleport.
func (client *Client) NewChat(username string, toolsConfig model.ToolsConfig) (*Chat, error) {
agent, err := model.NewAgent(username, toolsConfig)
tools := []model.Tool{
model.NewExecutionTool(),
}
if !toolsConfig.DisableEmbeddingsTool {
tools = append(tools, model.NewRetrievalTool(toolsConfig.EmbeddingsClient, toolsConfig.NodeClient,
toolsConfig.AccessChecker, username))
}
agent, err := model.NewAgent(tools...)
if err != nil {
return nil, trace.Wrap(err)
}
return &Chat{
client: client,
messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: model.PromptCharacter(username),
},
},
// Initialize a tokenizer for prompt token accounting.
// Cl100k is used by GPT-3 and GPT-4.
tokenizer: codec.NewCl100kBase(),
agent: agent,
}, nil
}

func (client *Client) NewCommand(username string) (*Chat, error) {
agent, err := model.NewAgent(model.NewGenerateTool())
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -121,7 +148,7 @@ func (client *Client) CommandSummary(ctx context.Context, messages []openai.Chat
return completion, tc, trace.Wrap(err)
}

// ClassifyMessage takes a user message, a list of categories, and uses the AI mode as a zero shot classifier.
// ClassifyMessage takes a user message, a list of categories, and uses the AI mode as a zero-shot classifier.
func (client *Client) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) {
resp, err := client.svc.CreateChatCompletion(
ctx,
Expand Down
60 changes: 45 additions & 15 deletions lib/ai/model/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,34 @@ const (
finalResponseHeader = "<FINAL RESPONSE>"
)

// NewAgent creates a new agent. The Assist agent which defines the model responsible for the Assist feature.
func NewAgent(username string, config ToolsConfig) (*Agent, error) {
err := config.CheckAndSetDefaults()
if err != nil {
return nil, trace.Wrap(err)
}
// NewExecutionTool creates a new execution tool. The execution tool is responsible for executing commands.
func NewExecutionTool() Tool {
return &commandExecutionTool{}
}

tools := []Tool{&commandExecutionTool{}}
// NewGenerateTool creates a new generation tool. The generation tool is responsible for generating Bash commands.
func NewGenerateTool() Tool {
return &commandGenerationTool{}
}

if !config.DisableEmbeddingsTool {
tools = append(tools,
&embeddingRetrievalTool{
assistClient: config.EmbeddingsClient,
currentUser: username,
nodeClient: config.NodeClient,
userAccessChecker: config.AccessChecker,
})
// NewRetrievalTool creates a new retrieval tool. The retrieval tool is responsible for retrieving embeddings.
func NewRetrievalTool(assistClient assist.AssistEmbeddingServiceClient,
nodeClient NodeGetter,
userAccessChecker services.AccessChecker,
currentUser string,
) Tool {
return &embeddingRetrievalTool{
assistClient: assistClient,
currentUser: currentUser,
nodeClient: nodeClient,
userAccessChecker: userAccessChecker,
}
}

// NewAgent creates a new agent. The Assist agent which defines the model responsible for the Assist feature.
func NewAgent(tools ...Tool) (*Agent, error) {
if len(tools) == 0 {
return nil, trace.BadParameter("at least one tool is required")
}

return &Agent{
Expand Down Expand Up @@ -264,6 +275,25 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres
return stepOutput{finish: &agentFinish{output: completion}}, nil
}

if tool, ok := tool.(*commandGenerationTool); ok {
input, err := tool.parseInput(action.Input)
if err != nil {
action := &AgentAction{
Action: actionException,
Input: observationPrefix + "Invalid or incomplete response",
Log: thoughtPrefix + err.Error(),
}

return stepOutput{action: action, observation: action.Input}, nil
}
completion := &GeneratedCommand{
Command: input.Command,
}

log.Tracef("agent decided on command generation, let's translate to an agentFinish")
return stepOutput{finish: &agentFinish{output: completion}}, nil
}

runOut, err := tool.Run(ctx, action.Input)
if err != nil {
return stepOutput{}, trace.Wrap(err)
Expand Down
80 changes: 80 additions & 0 deletions lib/ai/model/generationtool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package model

import (
"context"
"fmt"

"github.com/gravitational/trace"
)

type commandGenerationTool struct{}

type commandGenerationToolInput struct {
// Command is a unix command to execute.
Command string `json:"command"`
}

func (c *commandGenerationTool) Name() string {
return "Command Generation"
}

func (c *commandGenerationTool) Description() string {
// acknowledgement field is used to convince the LLM to return the JSON.
// Base on my testing LLM ignores the JSON when the schema has only one field.
// Adding additional "pseudo-fields" to the schema makes the LLM return the JSON.
return fmt.Sprintf(`Generate a Bash command.
The input must be a JSON object with the following schema:
%vjson
{
"command": string, \\ The generated command
"acknowledgement": boolean \\ Set to true to ackowledge that you understand the formatting
}
%v
`, "```", "```")
}

func (c *commandGenerationTool) Run(_ context.Context, _ string) (string, error) {
// This is stubbed because commandGenerationTool is handled specially.
// This is because execution of this tool breaks the loop and returns a command suggestion to the user.
// It is still handled as a tool because testing has shown that the LLM behaves better when it is treated as a tool.
//
// In addition, treating it as a Tool interface item simplifies the display and prompt assembly logic significantly.
return "", trace.NotImplemented("not implemented")
}

// parseInput is called in a special case if the planned tool is commandExecutionTool.
// This is because commandExecutionTool is handled differently from most other tools and forcibly terminates the thought loop.
func (*commandGenerationTool) parseInput(input string) (*commandGenerationToolInput, error) {
output, err := parseJSONFromModel[commandGenerationToolInput](input)
if err != nil {
return nil, err
}

if output.Command == "" {
return nil, &invalidOutputError{
coarse: "command generation: missing command",
detail: "command must be non-empty",
}
}

// Ignore the acknowledgement field.
// We do not care about the value. Having the command it enough.

return &output, nil
}
5 changes: 5 additions & 0 deletions lib/ai/model/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,8 @@ type CompletionCommand struct {
Nodes []string `json:"nodes,omitempty"`
Labels []Label `json:"labels,omitempty"`
}

// GeneratedCommand represents a Bash command generated by LLM.
type GeneratedCommand struct {
Command string `json:"command"`
}
3 changes: 2 additions & 1 deletion lib/ai/model/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ func ConversationCommandResult(result map[string][]byte) string {
message.WriteString(string(output))
message.WriteString("\n")
}
message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary.")
message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary. " +
"For error messages suggest a solution if possible. The solution can contain a Linux command or a description.")
return message.String()
}

Expand Down
84 changes: 83 additions & 1 deletion lib/assist/assist.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,31 @@ func (a *Assist) NewChat(ctx context.Context, assistService MessageService,
return chat, nil
}

// LightweightChat is a Teleport Assist chat that doesn't store the history
// of the conversation.
type LightweightChat struct {
assist *Assist
chat *ai.Chat
}

// NewLightweightChat creates a new Assist chat what doesn't store the history
// of the conversation.
func (a *Assist) NewLightweightChat(username string) (*LightweightChat, error) {
aichat, err := a.client.NewCommand(username) // TODO(jakule): fix this after all in-flight PRs are merged
if err != nil {
return nil, trace.Wrap(err)
}

return &LightweightChat{
assist: a,
chat: aichat,
}, nil
}

func (a *Assist) NewSSHCommand(username string) (*ai.Chat, error) {
return a.client.NewCommand(username)
}

// GenerateSummary generates a summary for the given message.
func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, error) {
return a.client.Summary(ctx, message)
Expand Down Expand Up @@ -179,7 +204,7 @@ func (c *Chat) reloadMessages(ctx context.Context) error {
}

// ClassifyMessage takes a user message, a list of categories, and uses the AI
// mode as a zero shot classifier. It returns an error if the classification
// mode as a zero-shot classifier. It returns an error if the classification
// result is not a valid class.
func (a *Assist) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) {
category, err := a.client.ClassifyMessage(ctx, message, classes)
Expand Down Expand Up @@ -406,6 +431,63 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use
return tokenCount, nil
}

// ProcessComplete processes a user message and returns the assistant's response.
func (c *LightweightChat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, userInput string,
) (*model.TokenCount, error) {
progressUpdates := func(update *model.AgentAction) {
payload, err := json.Marshal(update)
if err != nil {
log.WithError(err).Debugf("Failed to marshal progress update: %v", update)
return
}

if err := onMessage(MessageKindProgressUpdate, payload, c.assist.clock.Now().UTC()); err != nil {
log.WithError(err).Debugf("Failed to send progress update: %v", update)
return
}
}

message, tokenCount, err := c.chat.Reply(ctx, userInput, progressUpdates)
if err != nil {
return nil, trace.Wrap(err)
}

c.chat.Insert(openai.ChatMessageRoleUser, userInput)

switch message := message.(type) {
case *model.Message:
c.chat.Insert(openai.ChatMessageRoleAssistant, message.Content)
if err := onMessage(MessageKindAssistantMessage, []byte(message.Content), c.assist.clock.Now().UTC()); err != nil {
return nil, trace.Wrap(err)
}
case *model.GeneratedCommand:
c.chat.Insert(openai.ChatMessageRoleAssistant, message.Command)
if err := onMessage(MessageKindCommand, []byte(message.Command), c.assist.clock.Now().UTC()); err != nil {
return nil, trace.Wrap(err)
}
case *model.StreamingMessage:
if err := func() error {
var text strings.Builder
defer onMessage(MessageKindAssistantPartialFinalize, nil, c.assist.clock.Now().UTC())
for part := range message.Parts {
text.WriteString(part)

if err := onMessage(MessageKindAssistantPartialMessage, []byte(part), c.assist.clock.Now().UTC()); err != nil {
return trace.Wrap(err)
}
}
c.chat.Insert(openai.ChatMessageRoleAssistant, text.String())
return nil
}(); err != nil {
return nil, trace.Wrap(err)
}
default:
return nil, trace.Errorf("Unexpected message type: %T", message)
}

return tokenCount, nil
}

func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient PluginGetter) (string, error) {
// Try retrieving credentials from the plugin resource first
openaiPlugin, err := proxyClient.PluginsClient().GetPlugin(ctx, &pluginsv1.GetPluginRequest{
Expand Down
Loading

0 comments on commit e572328

Please sign in to comment.