Skip to content

Add initial support for bedrock #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ require (
github.com/alecthomas/chroma/v2 v2.15.0 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect
github.com/aws/smithy-go v1.20.3 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
Expand Down
28 changes: 28 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,34 @@ github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2 h1:h7qxtumNjKPWFv1QM/HJy60M
github.com/anthropics/anthropic-sdk-go v0.2.0-beta.2/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY=
github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 h1:tW1/Rkad38LA15X4UQtjXZXNKsCgkshC3EbmcUmghTg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3/go.mod h1:UbnqO+zjqk3uIt9yCACHJ9IVNhyhOCnYk8yA19SAWrM=
github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90=
github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI=
github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII=
github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM=
github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw=
github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE=
github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ=
github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE=
github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
Expand Down
7 changes: 7 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ type Model struct {
// TODO: Maybe support multiple models for different purposes
}

type AnthropicConfig struct {
DisableCache bool `json:"disableCache"`
UseBedrock bool `json:"useBedrock"`
}

type Provider struct {
APIKey string `json:"apiKey"`
Enabled bool `json:"enabled"`
Expand Down Expand Up @@ -130,6 +135,8 @@ func Load(debug bool) error {
defaultModelSet = true
}
}

viper.SetDefault("providers.bedrock.enabled", true)
// TODO: add more providers
cfg = &Config{}

Expand Down
23 changes: 23 additions & 0 deletions internal/llm/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,29 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
return nil, nil, err
}

case models.ProviderBedrock:
var err error
agentProvider, err = provider.NewBedrockProvider(
provider.WithBedrockSystemMessage(
prompt.CoderAnthropicSystemPrompt(),
),
provider.WithBedrockMaxTokens(maxTokens),
provider.WithBedrockModel(model),
)
if err != nil {
return nil, nil, err
}
titleGenerator, err = provider.NewBedrockProvider(
provider.WithBedrockSystemMessage(
prompt.TitlePrompt(),
),
provider.WithBedrockMaxTokens(maxTokens),
provider.WithBedrockModel(model),
)
if err != nil {
return nil, nil, err
}

}

return agentProvider, titleGenerator, nil
Expand Down
16 changes: 16 additions & 0 deletions internal/llm/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ const (

// GROQ
QWENQwq ModelID = "qwen-qwq"

// Bedrock
BedrockClaude37Sonnet ModelID = "bedrock.claude-3.7-sonnet"
)

const (
ProviderOpenAI ModelProvider = "openai"
ProviderAnthropic ModelProvider = "anthropic"
ProviderBedrock ModelProvider = "bedrock"
ProviderGemini ModelProvider = "gemini"
ProviderGROQ ModelProvider = "groq"
)
Expand Down Expand Up @@ -119,4 +123,16 @@ var SupportedModels = map[ModelID]Model{
CostPer1MOutCached: 0,
CostPer1MOut: 0,
},

// Bedrock
BedrockClaude37Sonnet: {
ID: BedrockClaude37Sonnet,
Name: "Bedrock: Claude 3.7 Sonnet",
Provider: ProviderBedrock,
APIModel: "anthropic.claude-3-7-sonnet-20250219-v1:0",
CostPer1MIn: 3.0,
CostPer1MInCached: 3.75,
CostPer1MOutCached: 0.30,
CostPer1MOut: 15.0,
},
}
33 changes: 28 additions & 5 deletions internal/llm/provider/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/bedrock"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
Expand All @@ -21,6 +22,8 @@ type anthropicProvider struct {
maxTokens int64
apiKey string
systemMessage string
useBedrock bool
disableCache bool
}

type AnthropicOption func(*anthropicProvider)
Expand Down Expand Up @@ -49,6 +52,18 @@ func WithAnthropicKey(apiKey string) AnthropicOption {
}
}

func WithAnthropicBedrock() AnthropicOption {
return func(a *anthropicProvider) {
a.useBedrock = true
}
}

func WithAnthropicDisableCache() AnthropicOption {
return func(a *anthropicProvider) {
a.disableCache = true
}
}

func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
provider := &anthropicProvider{
maxTokens: 1024,
Expand All @@ -62,7 +77,16 @@ func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
return nil, errors.New("system message is required")
}

provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
anthropicOptions := []option.RequestOption{}

if provider.apiKey != "" {
anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
}
if provider.useBedrock {
anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
}

provider.client = anthropic.NewClient(anthropicOptions...)
return provider, nil
}

Expand Down Expand Up @@ -338,7 +362,7 @@ func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []an
},
}

if i == len(tools)-1 {
if i == len(tools)-1 && !a.disableCache {
toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
Expand All @@ -358,7 +382,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
switch msg.Role {
case message.User:
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 {
if cachedBlocks < 2 && !a.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
Expand All @@ -370,7 +394,7 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag
blocks := []anthropic.ContentBlockParamUnion{}
if msg.Content().String() != "" {
content := anthropic.NewTextBlock(msg.Content().String())
if cachedBlocks < 2 {
if cachedBlocks < 2 && !a.disableCache {
content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
Type: "ephemeral",
}
Expand Down Expand Up @@ -404,4 +428,3 @@ func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Messag

return anthropicMessages
}

87 changes: 87 additions & 0 deletions internal/llm/provider/bedrock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package provider

import (
"context"
"errors"
"fmt"
"os"
"strings"

"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/message"
)

type bedrockProvider struct {
childProvider Provider
model models.Model
maxTokens int64
systemMessage string
}

func (b *bedrockProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
return b.childProvider.SendMessages(ctx, messages, tools)
}

func (b *bedrockProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
return b.childProvider.StreamResponse(ctx, messages, tools)
}

func NewBedrockProvider(opts ...BedrockOption) (Provider, error) {
provider := &bedrockProvider{}
for _, opt := range opts {
opt(provider)
}

// based on the AWS region prefix the model name with, us, eu, ap, sa, etc.
region := os.Getenv("AWS_REGION")
if region == "" {
region = os.Getenv("AWS_DEFAULT_REGION")
}

if region == "" {
return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is required")
}
if len(region) < 2 {
return nil, errors.New("AWS_REGION or AWS_DEFAULT_REGION environment variable is invalid")
}
regionPrefix := region[:2]
provider.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, provider.model.APIModel)

if strings.Contains(string(provider.model.APIModel), "anthropic") {
anthropic, err := NewAnthropicProvider(
WithAnthropicModel(provider.model),
WithAnthropicMaxTokens(provider.maxTokens),
WithAnthropicSystemMessage(provider.systemMessage),
WithAnthropicBedrock(),
WithAnthropicDisableCache(),
)
provider.childProvider = anthropic
if err != nil {
return nil, err
}
} else {
return nil, errors.New("unsupported model for bedrock provider")
}
return provider, nil
}

type BedrockOption func(*bedrockProvider)

func WithBedrockSystemMessage(message string) BedrockOption {
return func(a *bedrockProvider) {
a.systemMessage = message
}
}

func WithBedrockMaxTokens(maxTokens int64) BedrockOption {
return func(a *bedrockProvider) {
a.maxTokens = maxTokens
}
}

func WithBedrockModel(model models.Model) BedrockOption {
return func(a *bedrockProvider) {
a.model = model
}
}
16 changes: 14 additions & 2 deletions internal/tui/components/repl/editor.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package repl

import (
"log"
"strings"

"github.com/charmbracelet/bubbles/key"
Expand Down Expand Up @@ -138,11 +139,22 @@ func (m *editorCmp) SetSize(width int, height int) {

func (m *editorCmp) Send() tea.Cmd {
return func() tea.Msg {
messages, _ := m.app.Messages.List(m.sessionID)
messages, err := m.app.Messages.List(m.sessionID)
log.Printf("error: %v", err)
log.Printf("messages: %v", messages)

if err != nil {
return util.ReportError(err)
}
if hasUnfinishedMessages(messages) {
return util.ReportWarn("Assistant is still working on the previous message")
}
a, _ := agent.NewCoderAgent(m.app)
a, err := agent.NewCoderAgent(m.app)
log.Printf("error: %v", err)
log.Printf("agent: %v", a)
if err != nil {
return util.ReportError(err)
}

content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
go a.Generate(m.sessionID, content)
Expand Down