From d83b68f75cf66d4e1c57952303f4d818b74d1205 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Wed, 9 Apr 2025 17:45:41 +0200 Subject: [PATCH] add bedrock support --- go.mod | 14 +++++ go.sum | 28 +++++++++ internal/config/config.go | 7 +++ internal/llm/agent/agent.go | 23 +++++++ internal/llm/models/models.go | 16 +++++ internal/llm/provider/anthropic.go | 33 ++++++++-- internal/llm/provider/bedrock.go | 87 ++++++++++++++++++++++++++ internal/tui/components/repl/editor.go | 16 ++++- 8 files changed, 217 insertions(+), 7 deletions(-) create mode 100644 internal/llm/provider/bedrock.go diff --git a/go.mod b/go.mod index fd9298e..ab519a5 100644 --- a/go.mod +++ b/go.mod @@ -47,6 +47,20 @@ require ( github.com/alecthomas/chroma/v2 v2.15.0 // indirect github.com/andybalholm/cascadia v1.3.2 // indirect github.com/atotto/clipboard v0.1.4 // indirect + github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect + github.com/aws/smithy-go v1.20.3 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect diff --git a/go.sum b/go.sum index 10d753d..c4b32ef 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,34 @@ github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2 h1:h7qxtumNjKPWFv1QM/HJy60M github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= +github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3/go.mod h1:UbnqO+zjqk3uIt9yCACHJ9IVNhyhOCnYk8yA19SAWrM= +github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= +github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ= +github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= +github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= diff --git a/internal/config/config.go b/internal/config/config.go index 03b26e3..8029058 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,6 +36,11 @@ type Model struct { // TODO: Maybe support multiple models for different purposes } +type AnthropicConfig struct { + DisableCache bool `json:"disableCache"` + UseBedrock bool `json:"useBedrock"` +} + type Provider struct { APIKey string `json:"apiKey"` Enabled bool `json:"enabled"` @@ -130,6 +135,8 @@ func Load(debug bool) error { defaultModelSet = true } } + + viper.SetDefault("providers.bedrock.enabled", true) // TODO: add more providers cfg = &Config{} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index baf78be..78062d0 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -380,6 +380,29 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid return nil, nil, err } + case models.ProviderBedrock: + var err error + agentProvider, err = provider.NewBedrockProvider( + provider.WithBedrockSystemMessage( + prompt.CoderAnthropicSystemPrompt(), + ), + provider.WithBedrockMaxTokens(maxTokens), + provider.WithBedrockModel(model), + ) + if err != nil { + return nil, nil, err + } + titleGenerator, err = provider.NewBedrockProvider( + provider.WithBedrockSystemMessage( + prompt.TitlePrompt(), + ), + provider.WithBedrockMaxTokens(maxTokens), + provider.WithBedrockModel(model), + ) + if err != nil { + return nil, nil, err + } + } return agentProvider, titleGenerator, nil diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 2f75db9..4791218 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -31,11 +31,15 @@ const ( // GROQ QWENQwq ModelID = "qwen-qwq" + + // Bedrock + BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet" ) const ( ProviderOpenAI ModelProvider = "openai" ProviderAnthropic ModelProvider = "anthropic" + ProviderBedrock ModelProvider = "bedrock" ProviderGemini ModelProvider = "gemini" ProviderGROQ ModelProvider = "groq" ) @@ -119,4 +123,16 @@ var SupportedModels = map[ModelID]Model{ CostPer1MOutCached: 0, CostPer1MOut: 0, }, + + // Bedrock + BedrockClaude37Sonnet: { + ID: BedrockClaude37Sonnet, + Name: "Bedrock: Claude 3.7 Sonnet", + Provider: ProviderBedrock, + APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0", + CostPer1MIn: 3.0, + CostPer1MInCached: 3.75, + CostPer1MOutCached: 0.30, + CostPer1MOut: 15.0, + }, } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 02bd572..625976a 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -9,6 +9,7 @@ import ( "time" "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" "github.com/kujtimiihoxha/termai/internal/llm/models" "github.com/kujtimiihoxha/termai/internal/llm/tools" @@ -21,6 +22,8 @@ type anthropicProvider struct { maxTokens int64 apiKey string systemMessage string + useBedrock bool + disableCache bool } type AnthropicOption func(*anthropicProvider) @@ -49,6 +52,18 @@ func WithAnthropicKey(apiKey string) AnthropicOption { } } +func WithAnthropicBedrock() AnthropicOption { + return func(a *anthropicProvider) { + a.useBedrock = true + } +} + +func WithAnthropicDisableCache() AnthropicOption { + return func(a *anthropicProvider) { + a.disableCache = true + } +} + func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) { provider := &anthropicProvider{ maxTokens: 1024, @@ -62,7 +77,16 @@ func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) { return nil, errors.New("system message is required") } - provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey)) + anthropicOptions := []option.RequestOption{} + + if provider.apiKey != "" { + anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey)) + } + if provider.useBedrock { + anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background())) + } + + provider.client = anthropic.NewClient(anthropicOptions...) return provider, nil } @@ -338,7 +362,7 @@ func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []an }, } - if i == len(tools)-1 { + if i == len(tools)-1 && !a.disableCache { toolParam.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -358,7 +382,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag switch msg.Role { case message.User: content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 { + if cachedBlocks < 2 && !a.disableCache { content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -370,7 +394,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag blocks := []anthropic.ContentBlockParamUnion{} if msg.Content().String() != "" { content := anthropic.NewTextBlock(msg.Content().String()) - if cachedBlocks < 2 { + if cachedBlocks < 2 && !a.disableCache { content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{ Type: "ephemeral", } @@ -404,4 +428,3 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag return anthropicMessages } - diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go new file mode 100644 index 0000000..f1afefd --- /dev/null +++ b/internal/llm/provider/bedrock.go @@ -0,0 +1,87 @@ +package provider + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + + "github.com/kujtimiihoxha/termai/internal/llm/models" + "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/message" +) + +type bedrockProvider struct { + childProvider Provider + model models.Model + maxTokens int64 + systemMessage string +} + +func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) { + return b.childProvider.SendMessages(ctx, messages, tools) +} + +func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) { + return b.childProvider.StreamResponse(ctx, messages, tools) +} + +func NewBedrockProvider(opts ...BedrockOption) (Provider, error) { + provider := &bedrockProvider{} + for _, opt := range opts { + opt(provider) + } + + // based on the AWS region prefix the model name with, us, eu, ap, sa, etc. + region := os.Getenv("AWS_REGION") + if region == "" { + region = os.Getenv("AWS_DEFAULT_REGION") + } + + if region == "" { + return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is required") + } + if len(region) < 2 { + return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid") + } + regionPrefix := region[:2] + provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel) + + if strings.Contains(string(provider.model.APIModel), "anthropic") { + anthropic, err := NewAnthropicProvider( + WithAnthropicModel(provider.model), + WithAnthropicMaxTokens(provider.maxTokens), + WithAnthropicSystemMessage(provider.systemMessage), + WithAnthropicBedrock(), + WithAnthropicDisableCache(), + ) + provider.childProvider = anthropic + if err != nil { + return nil, err + } + } else { + return nil, errors.New("unsupported model for bedrock provider") + } + return provider, nil +} + +type BedrockOption func(*bedrockProvider) + +func WithBedrockSystemMessage(message string) BedrockOption { + return func(a *bedrockProvider) { + a.systemMessage = message + } +} + +func WithBedrockMaxTokens(maxTokens int64) BedrockOption { + return func(a *bedrockProvider) { + a.maxTokens = maxTokens + } +} + +func WithBedrockModel(model models.Model) BedrockOption { + return func(a *bedrockProvider) { + a.model = model + } +} diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go index d3e8d2c..f23de0e 100644 --- a/internal/tui/components/repl/editor.go +++ b/internal/tui/components/repl/editor.go @@ -1,6 +1,7 @@ package repl import ( + "log" "strings" "github.com/charmbracelet/bubbles/key" @@ -138,11 +139,22 @@ func (m *editorCmp) SetSize(width int, height int) { func (m *editorCmp) Send() tea.Cmd { return func() tea.Msg { - messages, _ := m.app.Messages.List(m.sessionID) + messages, err := m.app.Messages.List(m.sessionID) + log.Printf("error: %v", err) + log.Printf("messages: %v", messages) + + if err != nil { + return util.ReportError(err) + } if hasUnfinishedMessages(messages) { return util.ReportWarn("Assistant is still working on the previous message") } - a, _ := agent.NewCoderAgent(m.app) + a, err := agent.NewCoderAgent(m.app) + log.Printf("error: %v", err) + log.Printf("agent: %v", a) + if err != nil { + return util.ReportError(err) + } content := strings.Join(m.editor.GetBuffer().Lines(), "\n") go a.Generate(m.sessionID, content)