diff --git a/.github/workflows/pull_request.yaml b/.github/workflows/pull_request.yaml index 8e87f54..474235f 100644 --- a/.github/workflows/pull_request.yaml +++ b/.github/workflows/pull_request.yaml @@ -38,4 +38,5 @@ jobs: git_ref: ${{ github.event.pull_request.head.sha }} secrets: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} diff --git a/.github/workflows/push_main.yaml b/.github/workflows/push_main.yaml index 2bf364e..edca697 100644 --- a/.github/workflows/push_main.yaml +++ b/.github/workflows/push_main.yaml @@ -13,4 +13,5 @@ jobs: git_ref: '' secrets: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 5ef5d78..28eb3ae 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -9,6 +9,8 @@ on: secrets: OPENAI_API_KEY: required: true + ANTHROPIC_API_KEY: + required: true jobs: test-linux: @@ -21,7 +23,7 @@ jobs: - uses: actions/setup-go@v5 with: cache: false - go-version: "1.22" + go-version: "1.23" - name: Validate run: make validate - name: Install gptscript @@ -32,28 +34,5 @@ jobs: env: GPTSCRIPT_BIN: ./gptscriptexe OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - run: make test - - test-windows: - runs-on: windows-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 1 - ref: ${{ inputs.git_ref }} - - uses: actions/setup-go@v5 - with: - cache: false - go-version: "1.22" - - name: Install gptscript - run: | - curl https://get.gptscript.ai/releases/default_windows_amd64_v1/gptscript.exe -o gptscript.exe - - name: Create config file - run: | - echo '{"credsStore":"file"}' > config - - name: Run Tests - env: - GPTSCRIPT_BIN: .\gptscript.exe - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - GPTSCRIPT_CONFIG_FILE: .\config + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} run: make test diff --git a/.golangci.yaml b/.golangci.yaml index 62d0b9e..c2c3ee3 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,23 +1,35 @@ +version: "2" run: timeout: 5m - -output: - format: github-actions - linters: - disable-all: true + default: none enable: - errcheck - - gofmt - - gosimple - govet - ineffassign + - revive - staticcheck - - typecheck - thelper - unused - - goimports - whitespace - - revive - fast: false - max-same-issues: 50 + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + paths: + - third_party$ + - builtin$ + - examples$ +formatters: + enable: + - gofmt + - goimports + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/Makefile b/Makefile index 701a2fc..0227d0a 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ tidy: test: go test -v ./... -GOLANGCI_LINT_VERSION ?= v1.56.1 +GOLANGCI_LINT_VERSION ?= v2.1.2 lint: if ! command -v golangci-lint &> /dev/null; then \ echo "Could not find golangci-lint, installing version $(GOLANGCI_LINT_VERSION)."; \ diff --git a/README.md b/README.md index 418d4bd..725e81c 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,11 @@ The GPTScript instance allows the caller to run gptscript files, tools, and othe When creating a `GTPScript` instance, you can pass the following global options. These options are also available as run `Options`. Anything specified as a run option will take precedence over the global option. +- `CacheDir`: The directory to use for caching. Default (""), which uses the default cache directory. - `APIKey`: Specify an OpenAI API key for authenticating requests - `BaseURL`: A base URL for an OpenAI compatible API (the default is `https://api.openai.com/v1`) -- `DefaultModel`: The default model to use for OpenAI requests +- `DefaultModel`: The default model to use for chat completion requests +- `DefaultModelProvider`: The default model provider to use for chat completion requests - `Env`: Supply the environment variables. Supplying anything here means that nothing from the environment is used. The default is `os.Environ()`. Supplying `Env` at the run/evaluate level will be treated as "additional." ## Run Options @@ -46,32 +48,6 @@ As noted above, the Global Options are also available to specify here. These opt ## Functions -### listTools - -Lists all the available built-in tools. - -**Usage:** - -```go -package main - -import ( - "context" - - "github.com/gptscript-ai/go-gptscript" -) - -func listTools(ctx context.Context) (string, error) { - g, err := gptscript.NewGPTScript(gptscript.GlobalOptions{}) - if err != nil { - return "", err - } - defer g.Close() - - return g.ListTools(ctx) -} -``` - ### listModels Lists all the available models, returns a list. diff --git a/credentials.go b/credentials.go new file mode 100644 index 0000000..ca04c6c --- /dev/null +++ b/credentials.go @@ -0,0 +1,28 @@ +package gptscript + +import "time" + +type CredentialType string + +const ( + CredentialTypeTool CredentialType = "tool" + CredentialTypeModelProvider CredentialType = "modelProvider" +) + +type Credential struct { + Context string `json:"context"` + ToolName string `json:"toolName"` + Type CredentialType `json:"type"` + Env map[string]string `json:"env"` + Ephemeral bool `json:"ephemeral,omitempty"` + CheckParam string `json:"checkParam"` + ExpiresAt *time.Time `json:"expiresAt"` + RefreshToken string `json:"refreshToken"` +} + +type CredentialRequest struct { + Content string `json:"content"` + AllContexts bool `json:"allContexts"` + Context []string `json:"context"` + Name string `json:"name"` +} diff --git a/datasets.go b/datasets.go new file mode 100644 index 0000000..fb408c4 --- /dev/null +++ b/datasets.go @@ -0,0 +1,150 @@ +package gptscript + +import ( + "context" + "encoding/json" + "fmt" +) + +type DatasetElementMeta struct { + Name string `json:"name"` + Description string `json:"description"` +} + +type DatasetElement struct { + DatasetElementMeta `json:",inline"` + Contents string `json:"contents"` + BinaryContents []byte `json:"binaryContents"` +} + +type DatasetMeta struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` +} + +type datasetRequest struct { + Input string `json:"input"` + DatasetTool string `json:"datasetTool"` + Env []string `json:"env"` +} + +type addDatasetElementsArgs struct { + DatasetID string `json:"datasetID"` + Name string `json:"name"` + Description string `json:"description"` + Elements []DatasetElement `json:"elements"` +} + +type listDatasetElementArgs struct { + DatasetID string `json:"datasetID"` +} + +type getDatasetElementArgs struct { + DatasetID string `json:"datasetID"` + Element string `json:"name"` +} + +func (g *GPTScript) ListDatasets(ctx context.Context) ([]DatasetMeta, error) { + out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{ + Input: "{}", + DatasetTool: g.globalOpts.DatasetTool, + Env: g.globalOpts.Env, + }) + if err != nil { + return nil, err + } + + var datasets []DatasetMeta + if err = json.Unmarshal([]byte(out), &datasets); err != nil { + return nil, err + } + return datasets, nil +} + +type DatasetOptions struct { + Name, Description string +} + +func (g *GPTScript) CreateDatasetWithElements(ctx context.Context, elements []DatasetElement, options ...DatasetOptions) (string, error) { + return g.AddDatasetElements(ctx, "", elements, options...) +} + +func (g *GPTScript) AddDatasetElements(ctx context.Context, datasetID string, elements []DatasetElement, options ...DatasetOptions) (string, error) { + args := addDatasetElementsArgs{ + DatasetID: datasetID, + Elements: elements, + } + + for _, opt := range options { + if opt.Name != "" { + args.Name = opt.Name + } + if opt.Description != "" { + args.Description = opt.Description + } + } + + argsJSON, err := json.Marshal(args) + if err != nil { + return "", fmt.Errorf("failed to marshal element args: %w", err) + } + + return g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{ + Input: string(argsJSON), + DatasetTool: g.globalOpts.DatasetTool, + Env: g.globalOpts.Env, + }) +} + +func (g *GPTScript) ListDatasetElements(ctx context.Context, datasetID string) ([]DatasetElementMeta, error) { + args := listDatasetElementArgs{ + DatasetID: datasetID, + } + argsJSON, err := json.Marshal(args) + if err != nil { + return nil, fmt.Errorf("failed to marshal element args: %w", err) + } + + out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{ + Input: string(argsJSON), + DatasetTool: g.globalOpts.DatasetTool, + Env: g.globalOpts.Env, + }) + if err != nil { + return nil, err + } + + var elements []DatasetElementMeta + if err = json.Unmarshal([]byte(out), &elements); err != nil { + return nil, err + } + return elements, nil +} + +func (g *GPTScript) GetDatasetElement(ctx context.Context, datasetID, elementName string) (DatasetElement, error) { + args := getDatasetElementArgs{ + DatasetID: datasetID, + Element: elementName, + } + argsJSON, err := json.Marshal(args) + if err != nil { + return DatasetElement{}, fmt.Errorf("failed to marshal element args: %w", err) + } + + out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{ + Input: string(argsJSON), + DatasetTool: g.globalOpts.DatasetTool, + Env: g.globalOpts.Env, + }) + if err != nil { + return DatasetElement{}, err + } + + var element DatasetElement + if err = json.Unmarshal([]byte(out), &element); err != nil { + return DatasetElement{}, err + } + + return element, nil +} diff --git a/datasets_test.go b/datasets_test.go new file mode 100644 index 0000000..c1f1a92 --- /dev/null +++ b/datasets_test.go @@ -0,0 +1,98 @@ +package gptscript + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDatasets(t *testing.T) { + workspaceID, err := g.CreateWorkspace(context.Background(), "directory") + require.NoError(t, err) + + client, err := NewGPTScript(GlobalOptions{ + OpenAIAPIKey: os.Getenv("OPENAI_API_KEY"), + Env: append(os.Environ(), "GPTSCRIPT_WORKSPACE_ID="+workspaceID), + }) + require.NoError(t, err) + + defer func() { + _ = g.DeleteWorkspace(context.Background(), workspaceID) + }() + + datasetID, err := client.CreateDatasetWithElements(context.Background(), []DatasetElement{ + { + DatasetElementMeta: DatasetElementMeta{ + Name: "test-element-1", + Description: "This is a test element 1", + }, + Contents: "This is the content 1", + }, + }, DatasetOptions{ + Name: "test-dataset", + Description: "this is a test dataset", + }) + require.NoError(t, err) + + // Add three more elements + _, err = client.AddDatasetElements(context.Background(), datasetID, []DatasetElement{ + { + DatasetElementMeta: DatasetElementMeta{ + Name: "test-element-2", + Description: "This is a test element 2", + }, + Contents: "This is the content 2", + }, + { + DatasetElementMeta: DatasetElementMeta{ + Name: "test-element-3", + Description: "This is a test element 3", + }, + Contents: "This is the content 3", + }, + { + DatasetElementMeta: DatasetElementMeta{ + Name: "binary-element", + Description: "this element has binary contents", + }, + BinaryContents: []byte("binary contents"), + }, + }) + require.NoError(t, err) + + // Get the first element + element, err := client.GetDatasetElement(context.Background(), datasetID, "test-element-1") + require.NoError(t, err) + require.Equal(t, "test-element-1", element.Name) + require.Equal(t, "This is a test element 1", element.Description) + require.Equal(t, "This is the content 1", element.Contents) + + // Get the third element + element, err = client.GetDatasetElement(context.Background(), datasetID, "test-element-3") + require.NoError(t, err) + require.Equal(t, "test-element-3", element.Name) + require.Equal(t, "This is a test element 3", element.Description) + require.Equal(t, "This is the content 3", element.Contents) + + // Get the binary element + element, err = client.GetDatasetElement(context.Background(), datasetID, "binary-element") + require.NoError(t, err) + require.Equal(t, "binary-element", element.Name) + require.Equal(t, "this element has binary contents", element.Description) + require.Equal(t, []byte("binary contents"), element.BinaryContents) + + // List elements in the dataset + elements, err := client.ListDatasetElements(context.Background(), datasetID) + require.NoError(t, err) + require.Equal(t, 4, len(elements)) + + // List datasets + datasets, err := client.ListDatasets(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, len(datasets)) + require.Equal(t, datasetID, datasets[0].ID) + require.Equal(t, "test-dataset", datasets[0].Name) + require.Equal(t, "this is a test dataset", datasets[0].Description) +} diff --git a/frame.go b/frame.go index e3f37f8..4b77860 100644 --- a/frame.go +++ b/frame.go @@ -49,17 +49,30 @@ type RunFrame struct { Type EventType `json:"type"` } +type CallFrames map[string]CallFrame + +func (c CallFrames) ParentCallFrame() CallFrame { + for _, call := range c { + if call.ParentID == "" && call.ToolCategory == NoCategory { + return call + } + } + return CallFrame{} +} + type CallFrame struct { CallContext `json:",inline"` - Type EventType `json:"type"` - Start time.Time `json:"start"` - End time.Time `json:"end"` - Input string `json:"input"` - Output []Output `json:"output"` - Usage Usage `json:"usage"` - LLMRequest any `json:"llmRequest"` - LLMResponse any `json:"llmResponse"` + Type EventType `json:"type"` + Start time.Time `json:"start"` + End time.Time `json:"end"` + Input string `json:"input"` + Output []Output `json:"output"` + Usage Usage `json:"usage"` + ChatResponseCached bool `json:"chatResponseCached"` + ToolResults int `json:"toolResults"` + LLMRequest any `json:"llmRequest"` + LLMResponse any `json:"llmResponse"` } type Usage struct { @@ -103,17 +116,32 @@ type InputContext struct { Content string `json:"content,omitempty"` } +type Prompt struct { + Message string `json:"message,omitempty"` + Fields Fields `json:"fields,omitempty"` + Sensitive bool `json:"sensitive,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +type Field struct { + Name string `json:"name,omitempty"` + Sensitive *bool `json:"sensitive,omitempty"` + Description string `json:"description,omitempty"` + Options []string `json:"options,omitempty"` +} + +type Fields []Field + type PromptFrame struct { - ID string `json:"id,omitempty"` - Type EventType `json:"type,omitempty"` - Time time.Time `json:"time,omitempty"` - Message string `json:"message,omitempty"` - Fields []string `json:"fields,omitempty"` - Sensitive bool `json:"sensitive,omitempty"` + Prompt + ID string `json:"id,omitempty"` + Type EventType `json:"type,omitempty"` + Time time.Time `json:"time,omitempty"` } func (p *PromptFrame) String() string { return fmt.Sprintf(`Message: %s Fields: %v -Sensitive: %v`, p.Message, p.Fields, p.Sensitive) +Sensitive: %v`, p.Message, p.Fields, p.Sensitive, + ) } diff --git a/go.mod b/go.mod index fc551c5..a625857 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,15 @@ module github.com/gptscript-ai/go-gptscript -go 1.22.2 +go 1.24.2 -require github.com/getkin/kin-openapi v0.124.0 +require ( + github.com/modelcontextprotocol/go-sdk v0.2.0 + github.com/stretchr/testify v1.10.0 +) require ( - github.com/go-openapi/jsonpointer v0.20.2 // indirect - github.com/go-openapi/swag v0.22.8 // indirect - github.com/invopop/yaml v0.2.0 // indirect - github.com/josharian/intern v1.0.0 // indirect - github.com/mailru/easyjson v0.7.7 // indirect - github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect - github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9d94c4d..10b0c7a 100644 --- a/go.sum +++ b/go.sum @@ -1,38 +1,20 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/getkin/kin-openapi v0.124.0 h1:VSFNMB9C9rTKBnQ/fpyDU8ytMTr4dWI9QovSKj9kz/M= -github.com/getkin/kin-openapi v0.124.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM= -github.com/go-openapi/jsonpointer v0.20.2 h1:mQc3nmndL8ZBzStEo3JYF8wzmeWffDH4VbXz58sAx6Q= -github.com/go-openapi/jsonpointer v0.20.2/go.mod h1:bHen+N0u1KEO3YlmqOjTT9Adn1RfD91Ar825/PuiRVs= -github.com/go-openapi/swag v0.22.8 h1:/9RjDSQ0vbFR+NyjGMkFTsA1IA0fmhKSThmfGZjicbw= -github.com/go-openapi/swag v0.22.8/go.mod h1:6QT22icPLEqAM/z/TChgb4WAveCHF92+2gF0CNjHpPI= -github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= -github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= -github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY= -github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= -github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= -github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= -github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= -github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= -github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/modelcontextprotocol/go-sdk v0.2.0 h1:PESNYOmyM1c369tRkzXLY5hHrazj8x9CY1Xu0fLCryM= +github.com/modelcontextprotocol/go-sdk v0.2.0/go.mod h1:0sL9zUKKs2FTTkeCCVnKqbLJTw5TScefPAzojjU459E= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= -github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/gptscript.go b/gptscript.go index 8eb688c..3356982 100644 --- a/gptscript.go +++ b/gptscript.go @@ -2,7 +2,10 @@ package gptscript import ( "bufio" + "bytes" + "compress/gzip" "context" + "encoding/base64" "encoding/json" "fmt" "io" @@ -25,30 +28,33 @@ var ( const relativeToBinaryPath = "" type GPTScript struct { - url string + globalOpts GlobalOptions } -func NewGPTScript(opts GlobalOptions) (*GPTScript, error) { +func NewGPTScript(opts ...GlobalOptions) (*GPTScript, error) { + opt := completeGlobalOptions(opts...) + if opt.Env == nil { + opt.Env = os.Environ() + } + + opt.Env = append(opt.Env, opt.toEnv()...) + lock.Lock() defer lock.Unlock() gptscriptCount++ - disableServer := os.Getenv("GPTSCRIPT_DISABLE_SERVER") == "true" - - if serverURL == "" && disableServer { + startSDK := serverProcess == nil && serverURL == "" && opt.URL == "" + if serverURL == "" { serverURL = os.Getenv("GPTSCRIPT_URL") + startSDK = startSDK && serverURL == "" } - if serverProcessCancel == nil && !disableServer { + if startSDK { ctx, cancel := context.WithCancel(context.Background()) in, _ := io.Pipe() - serverProcess = exec.CommandContext(ctx, getCommand(), "sys.sdkserver", "--listen-address", serverURL) - if opts.Env == nil { - opts.Env = os.Environ() - } - - serverProcess.Env = append(opts.Env[:], opts.toEnv()...) + serverProcess = exec.CommandContext(ctx, getCommand(), "sys.sdkserver", "--listen-address", "127.0.0.1:0") + serverProcess.Env = opt.Env[:] serverProcess.Stdin = in stdErr, err := serverProcess.StderrPipe() @@ -88,7 +94,27 @@ func NewGPTScript(opts GlobalOptions) (*GPTScript, error) { serverURL = strings.TrimSpace(serverURL) } - return &GPTScript{url: "http://" + serverURL}, nil + + if opt.URL == "" { + opt.URL = serverURL + } + + if !strings.HasPrefix(opt.URL, "http://") && !strings.HasPrefix(opt.URL, "https://") { + opt.URL = "http://" + opt.URL + } + + opt.Env = append(opt.Env, "GPTSCRIPT_URL="+opt.URL) + + if opt.Token == "" { + opt.Token = os.Getenv("GPTSCRIPT_TOKEN") + } + if opt.Token != "" { + opt.Env = append(opt.Env, "GPTSCRIPT_TOKEN="+opt.Token) + } + + return &GPTScript{ + globalOpts: opt, + }, nil } func readAddress(stdErr io.Reader) (string, error) { @@ -105,6 +131,10 @@ func readAddress(stdErr io.Reader) (string, error) { return addr, nil } +func (g *GPTScript) URL() string { + return g.globalOpts.URL +} + func (g *GPTScript) Close() { lock.Lock() defer lock.Unlock() @@ -117,8 +147,10 @@ func (g *GPTScript) Close() { } func (g *GPTScript) Evaluate(ctx context.Context, opts Options, tools ...ToolDef) (*Run, error) { + opts.GlobalOptions = completeGlobalOptions(g.globalOpts, opts.GlobalOptions) return (&Run{ - url: g.url, + url: opts.URL, + token: opts.Token, requestPath: "evaluate", state: Creating, opts: opts, @@ -127,8 +159,10 @@ func (g *GPTScript) Evaluate(ctx context.Context, opts Options, tools ...ToolDef } func (g *GPTScript) Run(ctx context.Context, toolPath string, opts Options) (*Run, error) { + opts.GlobalOptions = completeGlobalOptions(g.globalOpts, opts.GlobalOptions) return (&Run{ - url: g.url, + url: opts.URL, + token: opts.Token, requestPath: "run", state: Creating, opts: opts, @@ -136,9 +170,23 @@ func (g *GPTScript) Run(ctx context.Context, toolPath string, opts Options) (*Ru }).NextChat(ctx, opts.Input) } +func (g *GPTScript) AbortRun(ctx context.Context, run *Run) error { + _, err := g.runBasicCommand(ctx, "abort/"+run.id, (map[string]any)(nil)) + return err +} + +type ParseOptions struct { + DisableCache bool +} + // Parse will parse the given file into an array of Nodes. -func (g *GPTScript) Parse(ctx context.Context, fileName string) ([]Node, error) { - out, err := g.runBasicCommand(ctx, "parse", map[string]any{"file": fileName}) +func (g *GPTScript) Parse(ctx context.Context, fileName string, opts ...ParseOptions) ([]Node, error) { + var disableCache bool + for _, opt := range opts { + disableCache = disableCache || opt.DisableCache + } + + out, err := g.runBasicCommand(ctx, "parse", map[string]any{"file": fileName, "disableCache": disableCache}) if err != nil { return nil, err } @@ -155,8 +203,8 @@ func (g *GPTScript) Parse(ctx context.Context, fileName string) ([]Node, error) return doc.Nodes, nil } -// ParseTool will parse the given string into a tool. -func (g *GPTScript) ParseTool(ctx context.Context, toolDef string) ([]Node, error) { +// ParseContent will parse the given string into a tool. +func (g *GPTScript) ParseContent(ctx context.Context, toolDef string) ([]Node, error) { out, err := g.runBasicCommand(ctx, "parse", map[string]any{"content": toolDef}) if err != nil { return nil, err @@ -188,6 +236,53 @@ func (g *GPTScript) Fmt(ctx context.Context, nodes []Node) (string, error) { return out, nil } +type LoadOptions struct { + DisableCache bool + SubTool string +} + +// LoadFile will load the given file into a Program. +func (g *GPTScript) LoadFile(ctx context.Context, fileName string, opts ...LoadOptions) (*Program, error) { + return g.load(ctx, map[string]any{"file": fileName}, opts...) +} + +// LoadContent will load the given content into a Program. +func (g *GPTScript) LoadContent(ctx context.Context, content string, opts ...LoadOptions) (*Program, error) { + return g.load(ctx, map[string]any{"content": content}, opts...) +} + +// LoadTools will load the given tools into a Program. +func (g *GPTScript) LoadTools(ctx context.Context, toolDefs []ToolDef, opts ...LoadOptions) (*Program, error) { + return g.load(ctx, map[string]any{"toolDefs": toolDefs}, opts...) +} + +func (g *GPTScript) load(ctx context.Context, payload map[string]any, opts ...LoadOptions) (*Program, error) { + for _, opt := range opts { + if opt.DisableCache { + payload["disableCache"] = true + } + if opt.SubTool != "" { + payload["subTool"] = opt.SubTool + } + } + + out, err := g.runBasicCommand(ctx, "load", payload) + if err != nil { + return nil, err + } + + type loadResponse struct { + Program *Program `json:"program"` + } + + prg := new(loadResponse) + if err = json.Unmarshal([]byte(out), prg); err != nil { + return nil, err + } + + return prg.Program, nil +} + // Version will return the output of `gptscript --version` func (g *GPTScript) Version(ctx context.Context) (string, error) { out, err := g.runBasicCommand(ctx, "version", nil) @@ -198,24 +293,64 @@ func (g *GPTScript) Version(ctx context.Context) (string, error) { return out, nil } -// ListTools will list all the available tools. -func (g *GPTScript) ListTools(ctx context.Context) (string, error) { - out, err := g.runBasicCommand(ctx, "list-tools", nil) - if err != nil { - return "", err - } +type ListModelsOptions struct { + Providers []string + CredentialOverrides []string +} - return out, nil +type Model struct { + CreatedAt int64 `json:"created"` + ID string `json:"id"` + Object string `json:"object"` + OwnedBy string `json:"owned_by"` + Permission []Permission `json:"permission"` + Root string `json:"root"` + Parent string `json:"parent"` + Metadata map[string]string `json:"metadata"` +} + +type Permission struct { + CreatedAt int64 `json:"created"` + ID string `json:"id"` + Object string `json:"object"` + AllowCreateEngine bool `json:"allow_create_engine"` + AllowSampling bool `json:"allow_sampling"` + AllowLogprobs bool `json:"allow_logprobs"` + AllowSearchIndices bool `json:"allow_search_indices"` + AllowView bool `json:"allow_view"` + AllowFineTuning bool `json:"allow_fine_tuning"` + Organization string `json:"organization"` + Group interface{} `json:"group"` + IsBlocking bool `json:"is_blocking"` } // ListModels will list all the available models. -func (g *GPTScript) ListModels(ctx context.Context) ([]string, error) { - out, err := g.runBasicCommand(ctx, "list-models", nil) +func (g *GPTScript) ListModels(ctx context.Context, opts ...ListModelsOptions) ([]Model, error) { + var o ListModelsOptions + for _, opt := range opts { + o.Providers = append(o.Providers, opt.Providers...) + o.CredentialOverrides = append(o.CredentialOverrides, opt.CredentialOverrides...) + } + + if g.globalOpts.DefaultModelProvider != "" { + o.Providers = append(o.Providers, g.globalOpts.DefaultModelProvider) + } + + out, err := g.runBasicCommand(ctx, "list-models", map[string]any{ + "providers": o.Providers, + "env": g.globalOpts.Env, + "credentialOverrides": o.CredentialOverrides, + }) if err != nil { return nil, err } - return strings.Split(strings.TrimSpace(out), "\n"), nil + var models []Model + if err = json.Unmarshal([]byte(out), &models); err != nil { + return nil, fmt.Errorf("failed to parse models: %w", err) + } + + return models, nil } func (g *GPTScript) Confirm(ctx context.Context, resp AuthResponse) error { @@ -228,9 +363,75 @@ func (g *GPTScript) PromptResponse(ctx context.Context, resp PromptResponse) err return err } +type ListCredentialsOptions struct { + CredentialContexts []string + AllContexts bool +} + +func (g *GPTScript) ListCredentials(ctx context.Context, opts ListCredentialsOptions) ([]Credential, error) { + req := CredentialRequest{} + if opts.AllContexts { + req.AllContexts = true + } else if len(opts.CredentialContexts) > 0 { + req.Context = opts.CredentialContexts + } else { + req.Context = []string{"default"} + } + + out, err := g.runBasicCommand(ctx, "credentials", req) + if err != nil { + return nil, err + } + + var creds []Credential + if err = json.Unmarshal([]byte(out), &creds); err != nil { + return nil, err + } + return creds, nil +} + +func (g *GPTScript) CreateCredential(ctx context.Context, cred Credential) error { + credJSON, err := json.Marshal(cred) + if err != nil { + return fmt.Errorf("failed to marshal credential: %w", err) + } + + _, err = g.runBasicCommand(ctx, "credentials/create", CredentialRequest{Content: string(credJSON)}) + return err +} + +func (g *GPTScript) RecreateAllCredentials(ctx context.Context) error { + _, err := g.runBasicCommand(ctx, "credentials/recreate-all", struct{}{}) + return err +} + +func (g *GPTScript) RevealCredential(ctx context.Context, credCtxs []string, name string) (Credential, error) { + out, err := g.runBasicCommand(ctx, "credentials/reveal", CredentialRequest{ + Context: credCtxs, + Name: name, + }) + if err != nil { + return Credential{}, err + } + + var cred Credential + if err = json.Unmarshal([]byte(out), &cred); err != nil { + return Credential{}, err + } + return cred, nil +} + +func (g *GPTScript) DeleteCredential(ctx context.Context, credCtx, name string) error { + _, err := g.runBasicCommand(ctx, "credentials/delete", CredentialRequest{ + Context: []string{credCtx}, // Only one context can be specified for delete operations + Name: name, + }) + return err +} + func (g *GPTScript) runBasicCommand(ctx context.Context, requestPath string, body any) (string, error) { run := &Run{ - url: g.url, + url: g.globalOpts.URL, requestPath: requestPath, state: Creating, basicCommand: true, @@ -276,3 +477,28 @@ func determineProperCommand(dir, bin string) string { slog.Debug("Using gptscript binary: " + bin) return bin } + +func GetEnv(key, def string) string { + v := os.Getenv(key) + if v == "" { + return def + } + + if strings.HasPrefix(v, `{"_gz":"`) && strings.HasSuffix(v, `"}`) { + data, err := base64.StdEncoding.DecodeString(v[8 : len(v)-2]) + if err != nil { + return v + } + gz, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return v + } + strBytes, err := io.ReadAll(gz) + if err != nil { + return v + } + return string(strBytes) + } + + return v +} diff --git a/gptscript_test.go b/gptscript_test.go index 1d85ef3..4cf7a6b 100644 --- a/gptscript_test.go +++ b/gptscript_test.go @@ -2,14 +2,19 @@ package gptscript import ( "context" + "errors" "fmt" + "math/rand" "os" "path/filepath" "runtime" + "strconv" "strings" "testing" + "time" - "github.com/getkin/kin-openapi/openapi3" + "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/stretchr/testify/require" ) var g *GPTScript @@ -19,14 +24,22 @@ func TestMain(m *testing.M) { panic("OPENAI_API_KEY or GPTSCRIPT_URL environment variable must be set") } - var err error + // Start an initial GPTScript instance. + // This one doesn't have any options, but it's there to ensure that using another instance works as expected in all cases. + gFirst, err := NewGPTScript(GlobalOptions{}) + if err != nil { + panic(fmt.Sprintf("error creating gptscript: %s", err)) + } + g, err = NewGPTScript(GlobalOptions{OpenAIAPIKey: os.Getenv("OPENAI_API_KEY")}) if err != nil { + gFirst.Close() panic(fmt.Sprintf("error creating gptscript: %s", err)) } exitCode := m.Run() g.Close() + gFirst.Close() os.Exit(exitCode) } @@ -58,19 +71,25 @@ func TestVersion(t *testing.T) { } } -func TestListTools(t *testing.T) { - tools, err := g.ListTools(context.Background()) +func TestListModels(t *testing.T) { + models, err := g.ListModels(context.Background()) if err != nil { - t.Errorf("Error listing tools: %v", err) + t.Errorf("Error listing models: %v", err) } - if len(tools) == 0 { - t.Error("No tools found") + if len(models) == 0 { + t.Error("No models found") } } -func TestListModels(t *testing.T) { - models, err := g.ListModels(context.Background()) +func TestListModelsWithProvider(t *testing.T) { + if os.Getenv("ANTHROPIC_API_KEY") == "" { + t.Skip("ANTHROPIC_API_KEY not set") + } + models, err := g.ListModels(context.Background(), ListModelsOptions{ + Providers: []string{"github.com/gptscript-ai/claude3-anthropic-provider"}, + CredentialOverrides: []string{"github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"}, + }) if err != nil { t.Errorf("Error listing models: %v", err) } @@ -78,9 +97,45 @@ func TestListModels(t *testing.T) { if len(models) == 0 { t.Error("No models found") } + + for _, model := range models { + if !strings.HasPrefix(model.ID, "claude-3-") || !strings.HasSuffix(model.ID, "from github.com/gptscript-ai/claude3-anthropic-provider") { + t.Errorf("Unexpected model name: %s", model.ID) + } + } } -func TestAbortRun(t *testing.T) { +func TestListModelsWithDefaultProvider(t *testing.T) { + if os.Getenv("ANTHROPIC_API_KEY") == "" { + t.Skip("ANTHROPIC_API_KEY not set") + } + g, err := NewGPTScript(GlobalOptions{ + DefaultModelProvider: "github.com/gptscript-ai/claude3-anthropic-provider", + }) + if err != nil { + t.Fatalf("Error creating gptscript: %v", err) + } + defer g.Close() + + models, err := g.ListModels(context.Background(), ListModelsOptions{ + CredentialOverrides: []string{"github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"}, + }) + if err != nil { + t.Errorf("Error listing models: %v", err) + } + + if len(models) == 0 { + t.Error("No models found") + } + + for _, model := range models { + if !strings.HasPrefix(model.ID, "claude-3-") || !strings.HasSuffix(model.ID, "from github.com/gptscript-ai/claude3-anthropic-provider") { + t.Errorf("Unexpected model name: %s", model.ID) + } + } +} + +func TestCancelRun(t *testing.T) { tool := ToolDef{Instructions: "What is the capital of the united states?"} run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) @@ -92,7 +147,7 @@ func TestAbortRun(t *testing.T) { <-run.Events() if err := run.Close(); err != nil { - t.Errorf("Error aborting run: %v", err) + t.Errorf("Error canceling run: %v", err) } if run.State() != Error { @@ -104,12 +159,81 @@ func TestAbortRun(t *testing.T) { } } +func TestAbortChatCompletionRun(t *testing.T) { + tool := ToolDef{Instructions: "Generate a real long essay about the meaning of life."} + + run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) + if err != nil { + t.Errorf("Error executing tool: %v", err) + } + + // Abort the run after the first event from the LLM + for e := range run.Events() { + if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." { + break + } + } + + if err := g.AbortRun(context.Background(), run); err != nil { + t.Errorf("Error aborting run: %v", err) + } + + // Wait for run to stop + for range run.Events() { + continue + } + + if run.State() != Finished { + t.Errorf("Unexpected run state: %s", run.State()) + } + + if out, err := run.Text(); err != nil { + t.Errorf("Error reading output: %v", err) + } else if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") { + t.Errorf("Unexpected output: %s", out) + } +} + +func TestAbortCommandRun(t *testing.T) { + tool := ToolDef{Instructions: "#!/usr/bin/env bash\necho Hello, world!\nsleep 5\necho Hello, again!\nsleep 5"} + + run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) + if err != nil { + t.Errorf("Error executing tool: %v", err) + } + + // Abort the run after the first event. + for e := range run.Events() { + if e.Call != nil && e.Call.Type == EventTypeChat { + time.Sleep(2 * time.Second) + break + } + } + + if err := g.AbortRun(context.Background(), run); err != nil { + t.Errorf("Error aborting run: %v", err) + } + + // Wait for run to stop + for range run.Events() { + continue + } + + if run.State() != Finished { + t.Errorf("Unexpected run state: %s", run.State()) + } + + if out, err := run.Text(); err != nil { + t.Errorf("Error reading output: %v", err) + } else if !strings.Contains(out, "Hello, world!") || strings.Contains(out, "Hello, again!") || !strings.HasSuffix(out, "\nABORTED BY USER") { + t.Errorf("Unexpected output: %s", out) + } +} + func TestSimpleEvaluate(t *testing.T) { tool := ToolDef{Instructions: "What is the capital of the united states?"} - run, err := g.Evaluate(context.Background(), Options{ - GlobalOptions: GlobalOptions{}, - }, tool) + run, err := g.Evaluate(context.Background(), Options{DisableCache: true}, tool) if err != nil { t.Errorf("Error executing tool: %v", err) } @@ -136,6 +260,17 @@ func TestSimpleEvaluate(t *testing.T) { if run.Program() == nil { t.Error("Run program not set") } + + var promptTokens, completionTokens, totalTokens int + for _, c := range run.calls { + promptTokens += c.Usage.PromptTokens + completionTokens += c.Usage.CompletionTokens + totalTokens += c.Usage.TotalTokens + } + + if promptTokens == 0 || completionTokens == 0 || totalTokens == 0 { + t.Errorf("Usage not set: %d, %d, %d", promptTokens, completionTokens, totalTokens) + } } func TestEvaluateWithContext(t *testing.T) { @@ -146,7 +281,7 @@ func TestEvaluateWithContext(t *testing.T) { tool := ToolDef{ Instructions: "What is the capital of the united states?", - Context: []string{ + Tools: []string{ wd + "/test/acorn-labs-context.gpt", }, } @@ -231,6 +366,16 @@ func TestEvaluateWithToolList(t *testing.T) { if !strings.Contains(out, "hello there") { t.Errorf("Unexpected output: %s", out) } + + // In this case, we expect the total number of tool results to be 1 + var toolResults int + for _, c := range run.calls { + toolResults += c.ToolResults + } + + if toolResults != 1 { + t.Errorf("Unexpected number of tool results: %d", toolResults) + } } func TestEvaluateWithToolListAndSubTool(t *testing.T) { @@ -307,6 +452,54 @@ func TestStreamEvaluate(t *testing.T) { } } +func TestSimpleRun(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Error getting working directory: %v", err) + } + + run, err := g.Run(context.Background(), wd+"/test/catcher.gpt", Options{}) + if err != nil { + t.Fatalf("Error executing file: %v", err) + } + + out, err := run.Text() + if err != nil { + t.Errorf("Error reading output: %v", err) + } + + if !strings.Contains(out, "Salinger") { + t.Errorf("Unexpected output: %s", out) + } + + if len(run.ErrorOutput()) != 0 { + t.Error("Should have no stderr output") + } + + // Run it a second time, ensuring the same output and that a cached response is used + run, err = g.Run(context.Background(), wd+"/test/catcher.gpt", Options{}) + if err != nil { + t.Fatalf("Error executing file: %v", err) + } + + secondOut, err := run.Text() + if err != nil { + t.Errorf("Error reading output: %v", err) + } + + if secondOut != out { + t.Errorf("Unexpected output on second run: %s != %s", out, secondOut) + } + + // In this case, we expect a single call and that the response is cached + for _, c := range run.calls { + if !c.ChatResponseCached { + t.Error("Chat response should be cached") + } + break + } +} + func TestStreamRun(t *testing.T) { wd, err := os.Getwd() if err != nil { @@ -345,6 +538,47 @@ func TestStreamRun(t *testing.T) { } } +func TestRestartFailedRun(t *testing.T) { + shebang := "#!/bin/bash" + instructions := "%s\nexit ${EXIT_CODE}" + if runtime.GOOS == "windows" { + shebang = "#!/usr/bin/env powershell.exe" + instructions = "%s\nexit $env:EXIT_CODE" + } + instructions = fmt.Sprintf(instructions, shebang) + tools := []ToolDef{ + { + Instructions: "say hello", + Tools: []string{"my-context"}, + }, + { + Name: "my-context", + Type: "context", + Instructions: instructions, + }, + } + run, err := g.Evaluate(context.Background(), Options{GlobalOptions: GlobalOptions{Env: []string{"EXIT_CODE=1"}}, DisableCache: true}, tools...) + if err != nil { + t.Fatalf("Error executing tool: %v", err) + } + + _, err = run.Text() + if err == nil { + t.Errorf("Expected error but got nil") + } + + run.opts.Env = nil + run, err = run.NextChat(context.Background(), "") + if err != nil { + t.Fatalf("Error executing next run: %v", err) + } + + _, err = run.Text() + if err != nil { + t.Errorf("Error reading output: %v", err) + } +} + func TestCredentialOverride(t *testing.T) { wd, err := os.Getwd() if err != nil { @@ -404,8 +638,60 @@ func TestParseSimpleFile(t *testing.T) { } } +func TestParseEmptyFile(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Error getting working directory: %v", err) + } + + tools, err := g.Parse(context.Background(), wd+"/test/empty.gpt") + if err != nil { + t.Errorf("Error parsing file: %v", err) + } + + if len(tools) != 0 { + t.Fatalf("Unexpected number of tools: %d", len(tools)) + } +} + +func TestParseFileWithMetadata(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Error getting working directory: %v", err) + } + + tools, err := g.Parse(context.Background(), wd+"/test/parse-with-metadata.gpt") + if err != nil { + t.Errorf("Error parsing file: %v", err) + } + + if len(tools) != 2 { + t.Fatalf("Unexpected number of tools: %d", len(tools)) + } + + if tools[0].ToolNode == nil { + t.Fatalf("No tool node found") + } + + if !strings.Contains(tools[0].ToolNode.Tool.Instructions, "requests.get(") { + t.Errorf("Unexpected instructions: %s", tools[0].ToolNode.Tool.Instructions) + } + + if tools[0].ToolNode.Tool.MetaData["requirements.txt"] != "requests" { + t.Errorf("Unexpected metadata: %s", tools[0].ToolNode.Tool.MetaData["requirements.txt"]) + } + + if tools[1].TextNode == nil { + t.Fatalf("No text node found") + } + + if tools[1].TextNode.Fmt != "metadata:foo:requirements.txt" { + t.Errorf("Unexpected text: %s", tools[1].TextNode.Fmt) + } +} + func TestParseTool(t *testing.T) { - tools, err := g.ParseTool(context.Background(), "echo hello") + tools, err := g.ParseContent(context.Background(), "echo hello") if err != nil { t.Errorf("Error parsing tool: %v", err) } @@ -423,8 +709,19 @@ func TestParseTool(t *testing.T) { } } +func TestEmptyParseTool(t *testing.T) { + tools, err := g.ParseContent(context.Background(), "") + if err != nil { + t.Errorf("Error parsing tool: %v", err) + } + + if len(tools) != 0 { + t.Fatalf("Unexpected number of tools: %d", len(tools)) + } +} + func TestParseToolWithTextNode(t *testing.T) { - tools, err := g.ParseTool(context.Background(), "echo hello\n---\n!markdown\nhello") + tools, err := g.ParseContent(context.Background(), "echo hello\n---\n!markdown\nhello") if err != nil { t.Errorf("Error parsing tool: %v", err) } @@ -445,7 +742,7 @@ func TestParseToolWithTextNode(t *testing.T) { t.Fatalf("No text node found") } - if tools[1].TextNode.Text != "hello\n" { + if strings.TrimSpace(tools[1].TextNode.Text) != "hello" { t.Errorf("Unexpected text: %s", tools[1].TextNode.Text) } if tools[1].TextNode.Fmt != "markdown" { @@ -471,14 +768,12 @@ func TestFmt(t *testing.T) { ToolDef: ToolDef{ Name: "echo", Instructions: "#!/bin/bash\necho hello there", - }, - Arguments: &openapi3.Schema{ - Type: &openapi3.Types{"object"}, - Properties: map[string]*openapi3.SchemaRef{ - "input": { - Value: &openapi3.Schema{ + Arguments: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "input": { Description: "The string input to echo", - Type: &openapi3.Types{"string"}, + Type: "string", }, }, }, @@ -532,14 +827,12 @@ func TestFmtWithTextNode(t *testing.T) { ToolDef: ToolDef{ Instructions: "#!/bin/bash\necho hello there", Name: "echo", - }, - Arguments: &openapi3.Schema{ - Type: &openapi3.Types{"object"}, - Properties: map[string]*openapi3.SchemaRef{ - "input": { - Value: &openapi3.Schema{ + Arguments: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "input": { Description: "The string input to echo", - Type: &openapi3.Types{"string"}, + Type: "string", }, }, }, @@ -619,6 +912,69 @@ func TestToolChat(t *testing.T) { } } +func TestAbortChat(t *testing.T) { + tool := ToolDef{ + Chat: true, + Instructions: "You are a chat bot. Don't finish the conversation until I say 'bye'.", + Tools: []string{"sys.chat.finish"}, + } + + run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) + if err != nil { + t.Fatalf("Error executing tool: %v", err) + } + inputs := []string{ + "Tell me a joke.", + "What was my first message?", + } + + // Just wait for the chat to start up. + for range run.Events() { + continue + } + + for i, input := range inputs { + run, err = run.NextChat(context.Background(), input) + if err != nil { + t.Fatalf("Error sending next input %q: %v", input, err) + } + + // Abort the run after the first event from the LLM + for e := range run.Events() { + if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." { + break + } + } + + if i == 0 { + if err := g.AbortRun(context.Background(), run); err != nil { + t.Fatalf("Error aborting run: %v", err) + } + } + + // Wait for the run to complete + for range run.Events() { + continue + } + + out, err := run.Text() + if err != nil { + t.Errorf("Error reading output: %s", run.ErrorOutput()) + t.Fatalf("Error reading output: %v", err) + } + + if i == 0 { + if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") { + t.Fatalf("Unexpected output: %s", out) + } + } else { + if !strings.Contains(out, "Tell me a joke") { + t.Errorf("Unexpected output: %s", out) + } + } + } +} + func TestFileChat(t *testing.T) { wd, err := os.Getwd() if err != nil { @@ -631,8 +987,8 @@ func TestFileChat(t *testing.T) { } inputs := []string{ "List the 3 largest of the Great Lakes by volume.", - "What is the volume of the second one in cubic miles?", - "What is the total area of the third one in square miles?", + "What is the second one in the list?", + "What is the third?", } expectedOutputs := []string{ @@ -674,28 +1030,30 @@ func TestToolWithGlobalTools(t *testing.T) { var eventContent string - run, err := g.Run(context.Background(), wd+"/test/global-tools.gpt", Options{DisableCache: true, IncludeEvents: true}) + run, err := g.Run(context.Background(), wd+"/test/global-tools.gpt", Options{DisableCache: true, IncludeEvents: true, CredentialOverrides: []string{"github.com/gptscript-ai/gateway:OPENAI_API_KEY"}}) if err != nil { t.Fatalf("Error executing tool: %v", err) } for e := range run.Events() { if e.Run != nil { - if e.Run.Type == EventTypeRunStart { + switch e.Run.Type { + case EventTypeRunStart: runStartSeen = true - } else if e.Run.Type == EventTypeRunFinish { + case EventTypeRunFinish: runFinishSeen = true } } else if e.Call != nil { - if e.Call.Type == EventTypeCallStart { + switch e.Call.Type { + case EventTypeCallStart: callStartSeen = true - } else if e.Call.Type == EventTypeCallFinish { + case EventTypeCallFinish: callFinishSeen = true for _, o := range e.Call.Output { eventContent += o.Content } - } else if e.Call.Type == EventTypeCallProgress { + case EventTypeCallProgress: callProgressSeen = true } } @@ -726,7 +1084,7 @@ func TestToolWithGlobalTools(t *testing.T) { func TestConfirm(t *testing.T) { var eventContent string tools := ToolDef{ - Instructions: "List the files in the current directory", + Instructions: "List all the files in the current directory. Respond with the names of the files in only the current directory.", Tools: []string{"sys.exec"}, } @@ -819,9 +1177,10 @@ func TestConfirmDeny(t *testing.T) { if confirmCallEvent == nil { t.Fatalf("No confirm call event") + return } - if !strings.Contains(confirmCallEvent.Input, "\"ls\"") { + if !strings.Contains(confirmCallEvent.Input, "ls") { t.Errorf("unexpected confirm input: %s", confirmCallEvent.Input) } @@ -890,6 +1249,7 @@ func TestPrompt(t *testing.T) { if promptFrame == nil { t.Fatalf("No prompt call event") + return } if promptFrame.Sensitive { @@ -904,13 +1264,13 @@ func TestPrompt(t *testing.T) { t.Fatalf("Unexpected number of fields: %d", len(promptFrame.Fields)) } - if promptFrame.Fields[0] != "first name" { - t.Errorf("Unexpected field: %s", promptFrame.Fields[0]) + if promptFrame.Fields[0].Name != "first name" { + t.Errorf("Unexpected field: %s", promptFrame.Fields[0].Name) } if err = g.PromptResponse(context.Background(), PromptResponse{ ID: promptFrame.ID, - Responses: map[string]string{promptFrame.Fields[0]: "Clicky"}, + Responses: map[string]string{promptFrame.Fields[0].Name: "Clicky"}, }); err != nil { t.Errorf("Error responding: %v", err) } @@ -942,6 +1302,70 @@ func TestPrompt(t *testing.T) { } } +func TestPromptWithMetadata(t *testing.T) { + run, err := g.Run(context.Background(), "sys.prompt", Options{IncludeEvents: true, Prompt: true, Input: `{"fields":"first name","metadata":{"key":"value"}}`}) + if err != nil { + t.Errorf("Error executing tool: %v", err) + } + + // Wait for the prompt event + var promptFrame *PromptFrame + for e := range run.Events() { + if e.Prompt != nil { + if e.Prompt.Type == EventTypePrompt { + promptFrame = e.Prompt + break + } + } + } + + if promptFrame == nil { + t.Fatalf("No prompt call event") + return + } + + if promptFrame.Sensitive { + t.Errorf("Unexpected sensitive prompt event: %v", promptFrame.Sensitive) + } + + if len(promptFrame.Fields) != 1 { + t.Fatalf("Unexpected number of fields: %d", len(promptFrame.Fields)) + } + + if promptFrame.Fields[0].Name != "first name" { + t.Errorf("Unexpected field: %s", promptFrame.Fields[0].Name) + } + + if promptFrame.Metadata["key"] != "value" { + t.Errorf("Unexpected metadata: %v", promptFrame.Metadata) + } + + if err = g.PromptResponse(context.Background(), PromptResponse{ + ID: promptFrame.ID, + Responses: map[string]string{promptFrame.Fields[0].Name: "Clicky"}, + }); err != nil { + t.Errorf("Error responding: %v", err) + } + + // Read the remainder of the events + //nolint:revive + for range run.Events() { + } + + out, err := run.Text() + if err != nil { + t.Errorf("Error reading output: %v", err) + } + + if !strings.Contains(out, "Clicky") { + t.Errorf("Unexpected output: %s", out) + } + + if len(run.ErrorOutput()) != 0 { + t.Errorf("Should have no stderr output: %v", run.ErrorOutput()) + } +} + func TestPromptWithoutPromptAllowed(t *testing.T) { tools := ToolDef{ Instructions: "Use the sys.prompt user to ask the user for 'first name' which is not sensitive. After you get their first name, say hello.", @@ -978,6 +1402,53 @@ func TestPromptWithoutPromptAllowed(t *testing.T) { } } +func TestPromptWithOptions(t *testing.T) { + run, err := g.Run(context.Background(), "sys.prompt", Options{IncludeEvents: true, Prompt: true, Input: `{"fields":[{"name":"Authentication Method","description":"The authentication token for the user","options":["API Key","OAuth"]}]}`}) + if err != nil { + t.Errorf("Error executing tool: %v", err) + } + + // Wait for the prompt event + var promptFrame *PromptFrame + for e := range run.Events() { + if e.Prompt != nil { + if e.Prompt.Type == EventTypePrompt { + promptFrame = e.Prompt + break + } + } + } + + if promptFrame == nil { + t.Fatalf("No prompt call event") + return + } + + if len(promptFrame.Fields) != 1 { + t.Fatalf("Unexpected number of fields: %d", len(promptFrame.Fields)) + } + + if promptFrame.Fields[0].Name != "Authentication Method" { + t.Errorf("Unexpected field: %s", promptFrame.Fields[0].Name) + } + + if promptFrame.Fields[0].Description != "The authentication token for the user" { + t.Errorf("Unexpected description: %s", promptFrame.Fields[0].Description) + } + + if len(promptFrame.Fields[0].Options) != 2 { + t.Fatalf("Unexpected number of options: %d", len(promptFrame.Fields[0].Options)) + } + + if promptFrame.Fields[0].Options[0] != "API Key" { + t.Errorf("Unexpected option: %s", promptFrame.Fields[0].Options[0]) + } + + if promptFrame.Fields[0].Options[1] != "OAuth" { + t.Errorf("Unexpected option: %s", promptFrame.Fields[0].Options[1]) + } +} + func TestGetCommand(t *testing.T) { currentEnvVar := os.Getenv("GPTSCRIPT_BIN") t.Cleanup(func() { @@ -1018,3 +1489,256 @@ func TestGetCommand(t *testing.T) { }) } } + +func TestGetEnv(t *testing.T) { + // Cleaning up + defer func(currentEnvValue string) { + os.Setenv("testKey", currentEnvValue) + }(os.Getenv("testKey")) + + // Tests + testCases := []struct { + name string + key string + def string + envValue string + expectedResult string + }{ + { + name: "NoValueUseDefault", + key: "testKey", + def: "defaultValue", + envValue: "", + expectedResult: "defaultValue", + }, + { + name: "ValueExistsNoCompress", + key: "testKey", + def: "defaultValue", + envValue: "testValue", + expectedResult: "testValue", + }, + { + name: "ValueExistsCompressed", + key: "testKey", + def: "defaultValue", + envValue: `{"_gz":"H4sIAEosrGYC/ytJLS5RKEvMKU0FACtB3ewKAAAA"}`, + + expectedResult: "test value", + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + os.Setenv(test.key, test.envValue) + + result := GetEnv(test.key, test.def) + + if result != test.expectedResult { + t.Errorf("expected: %s, got: %s", test.expectedResult, result) + } + }) + } +} + +func TestRunPythonWithMetadata(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Error getting working directory: %v", err) + } + + run, err := g.Run(context.Background(), wd+"/test/parse-with-metadata.gpt", Options{IncludeEvents: true}) + if err != nil { + t.Fatalf("Error executing file: %v", err) + } + + out, err := run.Text() + if err != nil { + t.Fatalf("Error reading output: %v", err) + } + + if out != "200" { + t.Errorf("Unexpected output: %s", out) + } +} + +func TestParseThenEvaluateWithMetadata(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Error getting working directory: %v", err) + } + + tools, err := g.Parse(context.Background(), wd+"/test/parse-with-metadata.gpt") + if err != nil { + t.Fatalf("Error parsing file: %v", err) + } + + run, err := g.Evaluate(context.Background(), Options{}, tools[0].ToolNode.Tool.ToolDef) + if err != nil { + t.Fatalf("Error executing file: %v", err) + } + + out, err := run.Text() + if err != nil { + t.Fatalf("Error reading output: %v", err) + } + + if out != "200" { + t.Errorf("Unexpected output: %s", out) + } +} + +func TestLoadFile(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Error getting working directory: %v", err) + } + + prg, err := g.LoadFile(context.Background(), wd+"/test/global-tools.gpt") + if err != nil { + t.Fatalf("Error loading file: %v", err) + } + + if prg.EntryToolID == "" { + t.Errorf("Unexpected entry tool ID: %s", prg.EntryToolID) + } + + if len(prg.ToolSet) == 0 { + t.Errorf("Unexpected number of tools: %d", len(prg.ToolSet)) + } + + if prg.Name == "" { + t.Errorf("Unexpected name: %s", prg.Name) + } +} + +func TestLoadRemoteFile(t *testing.T) { + prg, err := g.LoadFile(context.Background(), "github.com/gptscript-ai/context/workspace") + if err != nil { + t.Fatalf("Error loading file: %v", err) + } + + if prg.EntryToolID == "" { + t.Errorf("Unexpected entry tool ID: %s", prg.EntryToolID) + } + + if len(prg.ToolSet) == 0 { + t.Errorf("Unexpected number of tools: %d", len(prg.ToolSet)) + } + + if prg.Name == "" { + t.Errorf("Unexpected name: %s", prg.Name) + } +} + +func TestLoadContent(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("Error getting working directory: %v", err) + } + + content, err := os.ReadFile(wd + "/test/global-tools.gpt") + if err != nil { + t.Fatalf("Error reading file: %v", err) + } + + prg, err := g.LoadContent(context.Background(), string(content)) + if err != nil { + t.Fatalf("Error loading file: %v", err) + } + + if prg.EntryToolID == "" { + t.Errorf("Unexpected entry tool ID: %s", prg.EntryToolID) + } + + if len(prg.ToolSet) == 0 { + t.Errorf("Unexpected number of tools: %d", len(prg.ToolSet)) + } + + // Name won't be set in this case + if prg.Name != "" { + t.Errorf("Unexpected name: %s", prg.Name) + } +} + +func TestLoadTools(t *testing.T) { + tools := []ToolDef{ + { + Tools: []string{"echo"}, + Instructions: "echo 'hello there'", + }, + { + Name: "other", + Tools: []string{"echo"}, + Instructions: "echo 'hello somewhere else'", + }, + { + Name: "echo", + Tools: []string{"sys.exec"}, + Description: "Echoes the input", + Arguments: ObjectSchema("input", "The string input to echo"), + Instructions: "#!/bin/bash\n echo ${input}", + }, + } + + prg, err := g.LoadTools(context.Background(), tools) + if err != nil { + t.Fatalf("Error loading file: %v", err) + } + + if prg.EntryToolID == "" { + t.Errorf("Unexpected entry tool ID: %s", prg.EntryToolID) + } + + if len(prg.ToolSet) == 0 { + t.Errorf("Unexpected number of tools: %d", len(prg.ToolSet)) + } + + // Name won't be set in this case + if prg.Name != "" { + t.Errorf("Unexpected name: %s", prg.Name) + } +} + +func TestCredentials(t *testing.T) { + // We will test in the following order of create, list, reveal, delete. + name := "test-" + strconv.Itoa(rand.Int()) + if len(name) > 20 { + name = name[:20] + } + + // Create + err := g.CreateCredential(context.Background(), Credential{ + Context: "testing", + ToolName: name, + Type: CredentialTypeTool, + Env: map[string]string{"ENV": "testing"}, + RefreshToken: "my-refresh-token", + CheckParam: "my-check-param", + }) + require.NoError(t, err) + + // List + creds, err := g.ListCredentials(context.Background(), ListCredentialsOptions{ + CredentialContexts: []string{"testing"}, + }) + require.NoError(t, err) + require.GreaterOrEqual(t, len(creds), 1) + + // Reveal + cred, err := g.RevealCredential(context.Background(), []string{"testing"}, name) + require.NoError(t, err) + require.Contains(t, cred.Env, "ENV") + require.Equal(t, cred.Env["ENV"], "testing") + require.Equal(t, cred.RefreshToken, "my-refresh-token") + require.Equal(t, cred.CheckParam, "my-check-param") + + // Delete + err = g.DeleteCredential(context.Background(), "testing", name) + require.NoError(t, err) + + // Delete again and make sure we get a NotFoundError + err = g.DeleteCredential(context.Background(), "testing", name) + require.Error(t, err) + require.True(t, errors.As(err, &ErrNotFound{})) +} diff --git a/opts.go b/opts.go index ee1dac3..07507e2 100644 --- a/opts.go +++ b/opts.go @@ -3,10 +3,16 @@ package gptscript // GlobalOptions allows specification of settings that are used for every call made. // These options can be overridden by the corresponding Options. type GlobalOptions struct { - OpenAIAPIKey string `json:"APIKey"` - OpenAIBaseURL string `json:"BaseURL"` - DefaultModel string `json:"DefaultModel"` - Env []string `json:"env"` + URL string `json:"url"` + Token string `json:"token"` + OpenAIAPIKey string `json:"APIKey"` + OpenAIBaseURL string `json:"BaseURL"` + DefaultModel string `json:"DefaultModel"` + DefaultModelProvider string `json:"DefaultModelProvider"` + CacheDir string `json:"CacheDir"` + Env []string `json:"env"` + DatasetTool string `json:"DatasetTool"` + WorkspaceTool string `json:"WorkspaceTool"` } func (g GlobalOptions) toEnv() []string { @@ -20,24 +26,58 @@ func (g GlobalOptions) toEnv() []string { if g.DefaultModel != "" { args = append(args, "GPTSCRIPT_SDKSERVER_DEFAULT_MODEL="+g.DefaultModel) } + if g.DefaultModelProvider != "" { + args = append(args, "GPTSCRIPT_SDKSERVER_DEFAULT_MODEL_PROVIDER="+g.DefaultModelProvider) + } + if g.WorkspaceTool != "" { + args = append(args, "GPTSCRIPT_SDKSERVER_WORKSPACE_TOOL="+g.WorkspaceTool) + } return args } +func completeGlobalOptions(opts ...GlobalOptions) GlobalOptions { + var result GlobalOptions + for _, opt := range opts { + result.CacheDir = firstSet(opt.CacheDir, result.CacheDir) + result.URL = firstSet(opt.URL, result.URL) + result.Token = firstSet(opt.Token, result.Token) + result.OpenAIAPIKey = firstSet(opt.OpenAIAPIKey, result.OpenAIAPIKey) + result.OpenAIBaseURL = firstSet(opt.OpenAIBaseURL, result.OpenAIBaseURL) + result.DefaultModel = firstSet(opt.DefaultModel, result.DefaultModel) + result.DefaultModelProvider = firstSet(opt.DefaultModelProvider, result.DefaultModelProvider) + result.DatasetTool = firstSet(opt.DatasetTool, result.DatasetTool) + result.WorkspaceTool = firstSet(opt.WorkspaceTool, result.WorkspaceTool) + result.Env = append(result.Env, opt.Env...) + } + return result +} + +func firstSet[T comparable](in ...T) T { + var result T + for _, i := range in { + if i != result { + return i + } + } + + return result +} + // Options represents options for the gptscript tool or file. type Options struct { GlobalOptions `json:",inline"` + DisableCache bool `json:"disableCache"` Confirm bool `json:"confirm"` Input string `json:"input"` - DisableCache bool `json:"disableCache"` - CacheDir string `json:"cacheDir"` SubTool string `json:"subTool"` Workspace string `json:"workspace"` ChatState string `json:"chatState"` IncludeEvents bool `json:"includeEvents"` Prompt bool `json:"prompt"` CredentialOverrides []string `json:"credentialOverrides"` + CredentialContexts []string `json:"credentialContexts"` Location string `json:"location"` ForceSequential bool `json:"forceSequential"` } diff --git a/pkg/daemon/daemon.go b/pkg/daemon/daemon.go new file mode 100644 index 0000000..b25a49a --- /dev/null +++ b/pkg/daemon/daemon.go @@ -0,0 +1,102 @@ +package daemon + +import ( + "crypto/tls" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "net/http" + "os" +) + +type Server struct { + mux *http.ServeMux + tlsConfig *tls.Config +} + +// CreateServer creates a new HTTP server with TLS configured for GPTScript. +// This function should be used when creating a new server for a daemon tool. +// The server should then be started with the StartServer function. +func CreateServer() (*Server, error) { + return CreateServerWithMux(http.DefaultServeMux) +} + +// CreateServerWithMux creates a new HTTP server with TLS configured for GPTScript. +// This function should be used when creating a new server for a daemon tool with a custom ServeMux. +// The server should then be started with the StartServer function. +func CreateServerWithMux(mux *http.ServeMux) (*Server, error) { + tlsConfig, err := getTLSConfig() + if err != nil { + return nil, fmt.Errorf("failed to get TLS config: %v", err) + } + + return &Server{ + mux: mux, + tlsConfig: tlsConfig, + }, nil +} + +// Start starts an HTTP server created by the CreateServer function. +// This is for use with daemon tools. +func (s *Server) Start() error { + server := &http.Server{ + Addr: fmt.Sprintf("127.0.0.1:%s", os.Getenv("PORT")), + TLSConfig: s.tlsConfig, + Handler: s.mux, + } + + if err := server.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("stopped serving: %v", err) + } + return nil +} + +func (s *Server) HandleFunc(pattern string, handler http.HandlerFunc) { + s.mux.HandleFunc(pattern, handler) +} + +func getTLSConfig() (*tls.Config, error) { + certB64 := os.Getenv("CERT") + privateKeyB64 := os.Getenv("PRIVATE_KEY") + gptscriptCertB64 := os.Getenv("GPTSCRIPT_CERT") + + if certB64 == "" { + return nil, fmt.Errorf("CERT not set") + } else if privateKeyB64 == "" { + return nil, fmt.Errorf("PRIVATE_KEY not set") + } else if gptscriptCertB64 == "" { + return nil, fmt.Errorf("GPTSCRIPT_CERT not set") + } + + certBytes, err := base64.StdEncoding.DecodeString(certB64) + if err != nil { + return nil, fmt.Errorf("failed to decode cert base64: %v", err) + } + + privateKeyBytes, err := base64.StdEncoding.DecodeString(privateKeyB64) + if err != nil { + return nil, fmt.Errorf("failed to decode private key base64: %v", err) + } + + gptscriptCertBytes, err := base64.StdEncoding.DecodeString(gptscriptCertB64) + if err != nil { + return nil, fmt.Errorf("failed to decode gptscript cert base64: %v", err) + } + + cert, err := tls.X509KeyPair(certBytes, privateKeyBytes) + if err != nil { + return nil, fmt.Errorf("failed to create X509 key pair: %v", err) + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(gptscriptCertBytes) { + return nil, fmt.Errorf("failed to append gptscript cert to pool") + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, + }, nil +} diff --git a/run.go b/run.go index 6e6c5fd..983dccc 100644 --- a/run.go +++ b/run.go @@ -17,25 +17,34 @@ import ( var errAbortRun = errors.New("run aborted") +type ErrNotFound struct { + Message string +} + +func (e ErrNotFound) Error() string { + return e.Message +} + type Run struct { - url, requestPath, toolPath string - tools []ToolDef - opts Options - state RunState - chatState string - cancel context.CancelCauseFunc - err error - wait func() - basicCommand bool - - program *Program - callsLock sync.RWMutex - calls map[string]CallFrame - parentCallFrameID string - rawOutput map[string]any - output, errput string - events chan Frame - lock sync.Mutex + url, token, requestPath, toolPath string + tools []ToolDef + opts Options + state RunState + chatState string + cancel context.CancelCauseFunc + err error + wait func() + basicCommand bool + + program *Program + id string + callsLock sync.RWMutex + calls CallFrames + rawOutput map[string]any + output, errput string + events chan Frame + lock sync.Mutex + responseCode int } // Text returns the text output of the gptscript. It blocks until the output is ready. @@ -60,6 +69,11 @@ func (r *Run) State() RunState { // Err returns the error that caused the gptscript to fail, if any. func (r *Run) Err() error { if r.err != nil { + if r.responseCode == http.StatusNotFound { + return ErrNotFound{ + Message: fmt.Sprintf("run encountered an error: %s", r.errput), + } + } return fmt.Errorf("run encountered an error: %w with error output: %s", r.err, r.errput) } return nil @@ -67,8 +81,8 @@ func (r *Run) Err() error { // Program returns the gptscript program for the run. func (r *Run) Program() *Program { - r.lock.Lock() - defer r.lock.Unlock() + r.callsLock.Lock() + defer r.callsLock.Unlock() return r.program } @@ -90,7 +104,7 @@ func (r *Run) RespondingTool() Tool { } // Calls will return a flattened array of the calls for this run. -func (r *Run) Calls() map[string]CallFrame { +func (r *Run) Calls() CallFrames { r.callsLock.RLock() defer r.callsLock.RUnlock() return maps.Clone(r.calls) @@ -101,11 +115,22 @@ func (r *Run) ParentCallFrame() (CallFrame, bool) { r.callsLock.RLock() defer r.callsLock.RUnlock() - if r.parentCallFrameID == "" { - return CallFrame{}, false + return r.calls.ParentCallFrame(), true +} + +// Usage returns all the usage for this run. +func (r *Run) Usage() Usage { + var u Usage + r.callsLock.RLock() + defer r.callsLock.RUnlock() + + for _, c := range r.calls { + u.CompletionTokens += c.Usage.CompletionTokens + u.PromptTokens += c.Usage.PromptTokens + u.TotalTokens += c.Usage.TotalTokens } - return r.calls[r.parentCallFrameID], true + return u } // ErrorOutput returns the stderr output of the gptscript. @@ -126,6 +151,12 @@ func (r *Run) Close() error { return fmt.Errorf("run not started") } + if r.lock.TryLock() { + r.lock.Unlock() + // If we can get the lock, then the run isn't running, so nothing to do. + return nil + } + r.cancel(errAbortRun) if r.wait == nil { return nil @@ -175,18 +206,24 @@ func (r *Run) NextChat(ctx context.Context, input string) (*Run, error) { run.opts.ChatState = r.chatState } - var payload any + var ( + payload any + options = run.opts + ) + // Remove the url and token because they shouldn't be sent with the payload. + options.URL = "" + options.Token = "" if len(r.tools) != 0 { payload = requestPayload{ ToolDefs: r.tools, Input: input, - Options: run.opts, + Options: options, } } else if run.toolPath != "" { payload = requestPayload{ File: run.toolPath, Input: input, - Options: run.opts, + Options: options, } } @@ -228,6 +265,10 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { return r.err } + if r.opts.Token != "" { + req.Header.Set("Authorization", "Bearer "+r.opts.Token) + } + resp, err := http.DefaultClient.Do(req) if err != nil { r.state = Error @@ -235,9 +276,10 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { return r.err } + r.responseCode = resp.StatusCode if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { r.state = Error - r.err = fmt.Errorf("run encountered an error") + r.err = fmt.Errorf("run encountered an error: status code %d", resp.StatusCode) } else { r.state = Running } @@ -265,10 +307,10 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { ) defer func() { resp.Body.Close() - close(r.events) cancel(r.err) r.wait() r.lock.Unlock() + close(r.events) }() r.callsLock.Lock() @@ -325,6 +367,15 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { done, _ = out["done"].(bool) r.rawOutput = out + case []any: + b, err := json.Marshal(out) + if err != nil { + r.state = Error + r.err = fmt.Errorf("failed to process stdout: %w", err) + return + } + + r.output = string(b) default: r.state = Error r.err = fmt.Errorf("failed to process stdout, invalid type: %T", out) @@ -360,18 +411,16 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { if event.Call != nil { r.callsLock.Lock() r.calls[event.Call.ID] = *event.Call - if r.parentCallFrameID == "" && event.Call.ParentID == "" { - r.parentCallFrameID = event.Call.ID - } r.callsLock.Unlock() } else if event.Run != nil { if event.Run.Type == EventTypeRunStart { r.callsLock.Lock() r.program = &event.Run.Program + r.id = event.Run.ID r.callsLock.Unlock() } else if event.Run.Type == EventTypeRunFinish && event.Run.Error != "" { r.state = Error - r.err = fmt.Errorf(event.Run.Error) + r.err = fmt.Errorf("%s", event.Run.Error) } } @@ -382,7 +431,7 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { } } - if err != nil && !errors.Is(err, io.EOF) { + if !errors.Is(err, io.EOF) { slog.Debug("failed to read events from response", "error", err) r.err = fmt.Errorf("failed to read events: %w", err) } diff --git a/run_test.go b/run_test.go index f9014a1..ac1e1d8 100644 --- a/run_test.go +++ b/run_test.go @@ -2,8 +2,13 @@ package gptscript import ( "context" + "crypto/rand" + "encoding/hex" + "os" "runtime" "testing" + + "github.com/stretchr/testify/require" ) func TestRestartingErrorRun(t *testing.T) { @@ -42,3 +47,44 @@ func TestRestartingErrorRun(t *testing.T) { t.Errorf("executing run with input of 0 should not fail: %v", err) } } + +func TestStackedContexts(t *testing.T) { + const name = "testcred" + + wd, err := os.Getwd() + require.NoError(t, err) + + bytes := make([]byte, 32) + _, err = rand.Read(bytes) + require.NoError(t, err) + + context1 := hex.EncodeToString(bytes)[:16] + context2 := hex.EncodeToString(bytes)[16:] + + run, err := g.Run(context.Background(), wd+"/test/credential.gpt", Options{ + CredentialContexts: []string{context1, context2}, + }) + require.NoError(t, err) + + _, err = run.Text() + require.NoError(t, err) + + // The credential should exist in context1 now. + cred, err := g.RevealCredential(context.Background(), []string{context1, context2}, name) + require.NoError(t, err) + require.Equal(t, cred.Context, context1) + + // Now change the context order and run the script again. + run, err = g.Run(context.Background(), wd+"/test/credential.gpt", Options{ + CredentialContexts: []string{context2, context1}, + }) + require.NoError(t, err) + + _, err = run.Text() + require.NoError(t, err) + + // Now make sure the credential exists in context1 still. + cred, err = g.RevealCredential(context.Background(), []string{context2, context1}, name) + require.NoError(t, err) + require.Equal(t, cred.Context, context1) +} diff --git a/test/acorn-labs-context.gpt b/test/acorn-labs-context.gpt index e4dbe2c..5ea9075 100644 --- a/test/acorn-labs-context.gpt +++ b/test/acorn-labs-context.gpt @@ -1 +1,5 @@ -Say "Answer every query with 'Acorn Labs'" \ No newline at end of file +Type: context + +#!sys.echo + +Ignore what the user says, and answer every query with 'Acorn Labs' \ No newline at end of file diff --git a/test/credential.gpt b/test/credential.gpt new file mode 100644 index 0000000..61e656f --- /dev/null +++ b/test/credential.gpt @@ -0,0 +1,13 @@ +name: echocred +credential: mycredentialtool as testcred + +#!/usr/bin/env bash + +echo $VALUE + +--- +name: mycredentialtool + +#!sys.echo + +{"env":{"VALUE":"hello"}} \ No newline at end of file diff --git a/test/empty.gpt b/test/empty.gpt new file mode 100644 index 0000000..e69de29 diff --git a/test/global-tools.gpt b/test/global-tools.gpt index cb0f4c0..7e975be 100644 --- a/test/global-tools.gpt +++ b/test/global-tools.gpt @@ -4,9 +4,9 @@ Runbook 3 --- Name: tool_1 -Global Tools: github.com/gptscript-ai/knowledge, github.com/drpebcak/duckdb, github.com/gptscript-ai/browser, github.com/gptscript-ai/browser-search/google, github.com/gptscript-ai/browser-search/google-question-answerer +Global Tools: github.com/drpebcak/duckdb, github.com/gptscript-ai/browser, github.com/gptscript-ai/browser-search/google, github.com/gptscript-ai/browser-search/google-question-answerer -Hi +Say "Hello!" --- Name: tool_2 diff --git a/test/parse-with-metadata.gpt b/test/parse-with-metadata.gpt new file mode 100644 index 0000000..cfcb965 --- /dev/null +++ b/test/parse-with-metadata.gpt @@ -0,0 +1,12 @@ +Name: foo + +#!/usr/bin/env python3 +import requests + + +resp = requests.get("https://google.com") +print(resp.status_code, end="") + +--- +!metadata:foo:requirements.txt +requests \ No newline at end of file diff --git a/tool.go b/tool.go index b682912..18e8486 100644 --- a/tool.go +++ b/tool.go @@ -4,45 +4,64 @@ import ( "fmt" "strings" - "github.com/getkin/kin-openapi/openapi3" + "github.com/modelcontextprotocol/go-sdk/jsonschema" ) // ToolDef struct represents a tool with various configurations. type ToolDef struct { - Name string `json:"name,omitempty"` - Description string `json:"description,omitempty"` - MaxTokens int `json:"maxTokens,omitempty"` - ModelName string `json:"modelName,omitempty"` - ModelProvider bool `json:"modelProvider,omitempty"` - JSONResponse bool `json:"jsonResponse,omitempty"` - Chat bool `json:"chat,omitempty"` - Temperature *float32 `json:"temperature,omitempty"` - Cache *bool `json:"cache,omitempty"` - InternalPrompt *bool `json:"internalPrompt"` - Arguments *openapi3.Schema `json:"arguments,omitempty"` - Tools []string `json:"tools,omitempty"` - GlobalTools []string `json:"globalTools,omitempty"` - GlobalModelName string `json:"globalModelName,omitempty"` - Context []string `json:"context,omitempty"` - ExportContext []string `json:"exportContext,omitempty"` - Export []string `json:"export,omitempty"` - Agents []string `json:"agents,omitempty"` - Credentials []string `json:"credentials,omitempty"` - Instructions string `json:"instructions,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + MaxTokens int `json:"maxTokens,omitempty"` + ModelName string `json:"modelName,omitempty"` + ModelProvider bool `json:"modelProvider,omitempty"` + JSONResponse bool `json:"jsonResponse,omitempty"` + Chat bool `json:"chat,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + Cache *bool `json:"cache,omitempty"` + InternalPrompt *bool `json:"internalPrompt"` + Arguments *jsonschema.Schema `json:"arguments,omitempty"` + Tools []string `json:"tools,omitempty"` + GlobalTools []string `json:"globalTools,omitempty"` + GlobalModelName string `json:"globalModelName,omitempty"` + Context []string `json:"context,omitempty"` + ExportContext []string `json:"exportContext,omitempty"` + Export []string `json:"export,omitempty"` + Agents []string `json:"agents,omitempty"` + Credentials []string `json:"credentials,omitempty"` + ExportCredentials []string `json:"exportCredentials,omitempty"` + InputFilters []string `json:"inputFilters,omitempty"` + ExportInputFilters []string `json:"exportInputFilters,omitempty"` + OutputFilters []string `json:"outputFilters,omitempty"` + ExportOutputFilters []string `json:"exportOutputFilters,omitempty"` + Instructions string `json:"instructions,omitempty"` + Type string `json:"type,omitempty"` + MetaData map[string]string `json:"metadata,omitempty"` } -func ObjectSchema(kv ...string) *openapi3.Schema { - s := &openapi3.Schema{ - Type: &openapi3.Types{"object"}, - Properties: openapi3.Schemas{}, +func ToolDefsToNodes(tools []ToolDef) []Node { + nodes := make([]Node, 0, len(tools)) + for _, tool := range tools { + nodes = append(nodes, Node{ + ToolNode: &ToolNode{ + Tool: Tool{ + ToolDef: tool, + }, + }, + }) + } + return nodes +} + +func ObjectSchema(kv ...string) *jsonschema.Schema { + s := &jsonschema.Schema{ + Type: "object", + Properties: make(map[string]*jsonschema.Schema, len(kv)/2), } for i, v := range kv { if i%2 == 1 { - s.Properties[kv[i-1]] = &openapi3.SchemaRef{ - Value: &openapi3.Schema{ - Description: v, - Type: &openapi3.Types{"string"}, - }, + s.Properties[kv[i-1]] = &jsonschema.Schema{ + Description: v, + Type: "string", } } } @@ -84,7 +103,6 @@ type ToolNode struct { type Tool struct { ToolDef `json:",inline"` ID string `json:"id,omitempty"` - Arguments *openapi3.Schema `json:"arguments,omitempty"` ToolMapping map[string][]ToolReference `json:"toolMapping,omitempty"` LocalTools map[string]string `json:"localTools,omitempty"` Source ToolSource `json:"source,omitempty"` diff --git a/workspace.go b/workspace.go new file mode 100644 index 0000000..f384a85 --- /dev/null +++ b/workspace.go @@ -0,0 +1,436 @@ +package gptscript + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "regexp" + "strings" + "time" +) + +var conflictErrParser = regexp.MustCompile(`^.+500 Internal Server Error: conflict: (.+)/([^/]+) \(latest revision: (-?\d+), current revision: (-?\d+)\)$`) + +type NotFoundInWorkspaceError struct { + id string + name string +} + +func (e *NotFoundInWorkspaceError) Error() string { + return fmt.Sprintf("not found: %s/%s", e.id, e.name) +} + +func newNotFoundInWorkspaceError(id, name string) *NotFoundInWorkspaceError { + return &NotFoundInWorkspaceError{id: id, name: name} +} + +type ConflictInWorkspaceError struct { + ID string + Name string + LatestRevision string + CurrentRevision string +} + +func parsePossibleConflictInWorkspaceError(err error) error { + if err == nil { + return err + } + + matches := conflictErrParser.FindStringSubmatch(err.Error()) + if len(matches) != 5 { + return err + } + return &ConflictInWorkspaceError{ID: matches[1], Name: matches[2], LatestRevision: matches[3], CurrentRevision: matches[4]} +} + +func (e *ConflictInWorkspaceError) Error() string { + return fmt.Sprintf("conflict: %s/%s (latest revision: %s, current revision: %s)", e.ID, e.Name, e.LatestRevision, e.CurrentRevision) +} + +func (g *GPTScript) CreateWorkspace(ctx context.Context, providerType string, fromWorkspaces ...string) (string, error) { + out, err := g.runBasicCommand(ctx, "workspaces/create", map[string]any{ + "providerType": providerType, + "fromWorkspaceIDs": fromWorkspaces, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + if err != nil { + return "", err + } + + return strings.TrimSpace(out), nil +} + +func (g *GPTScript) DeleteWorkspace(ctx context.Context, workspaceID string) error { + if workspaceID == "" { + return fmt.Errorf("workspace ID cannot be empty") + } + + _, err := g.runBasicCommand(ctx, "workspaces/delete", map[string]any{ + "id": workspaceID, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + + return err +} + +type ListFilesInWorkspaceOptions struct { + WorkspaceID string + Prefix string +} + +func (g *GPTScript) ListFilesInWorkspace(ctx context.Context, opts ...ListFilesInWorkspaceOptions) ([]string, error) { + var opt ListFilesInWorkspaceOptions + for _, o := range opts { + if o.Prefix != "" { + opt.Prefix = o.Prefix + } + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + out, err := g.runBasicCommand(ctx, "workspaces/list", map[string]any{ + "id": opt.WorkspaceID, + "prefix": opt.Prefix, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + if err != nil { + return nil, err + } + + out = strings.TrimSpace(out) + if len(out) == 0 { + return nil, nil + } + + var files []string + return files, json.Unmarshal([]byte(out), &files) +} + +type RemoveAllOptions struct { + WorkspaceID string + WithPrefix string +} + +func (g *GPTScript) RemoveAll(ctx context.Context, opts ...RemoveAllOptions) error { + var opt RemoveAllOptions + for _, o := range opts { + if o.WithPrefix != "" { + opt.WithPrefix = o.WithPrefix + } + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + _, err := g.runBasicCommand(ctx, "workspaces/remove-all-with-prefix", map[string]any{ + "id": opt.WorkspaceID, + "prefix": opt.WithPrefix, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + + return err +} + +type WriteFileInWorkspaceOptions struct { + WorkspaceID string + CreateRevision *bool + LatestRevisionID string +} + +func (g *GPTScript) WriteFileInWorkspace(ctx context.Context, filePath string, contents []byte, opts ...WriteFileInWorkspaceOptions) error { + var opt WriteFileInWorkspaceOptions + for _, o := range opts { + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + if o.CreateRevision != nil { + opt.CreateRevision = o.CreateRevision + } + if o.LatestRevisionID != "" { + opt.LatestRevisionID = o.LatestRevisionID + } + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + _, err := g.runBasicCommand(ctx, "workspaces/write-file", map[string]any{ + "id": opt.WorkspaceID, + "contents": base64.StdEncoding.EncodeToString(contents), + "filePath": filePath, + "createRevision": opt.CreateRevision, + "latestRevisionID": opt.LatestRevisionID, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + + return parsePossibleConflictInWorkspaceError(err) +} + +type DeleteFileInWorkspaceOptions struct { + WorkspaceID string +} + +func (g *GPTScript) DeleteFileInWorkspace(ctx context.Context, filePath string, opts ...DeleteFileInWorkspaceOptions) error { + var opt DeleteFileInWorkspaceOptions + for _, o := range opts { + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + _, err := g.runBasicCommand(ctx, "workspaces/delete-file", map[string]any{ + "id": opt.WorkspaceID, + "filePath": filePath, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + + if err != nil && strings.HasSuffix(err.Error(), fmt.Sprintf("not found: %s/%s", opt.WorkspaceID, filePath)) { + return newNotFoundInWorkspaceError(opt.WorkspaceID, filePath) + } + + return err +} + +type ReadFileInWorkspaceOptions struct { + WorkspaceID string +} + +func (g *GPTScript) ReadFileInWorkspace(ctx context.Context, filePath string, opts ...ReadFileInWorkspaceOptions) ([]byte, error) { + var opt ReadFileInWorkspaceOptions + for _, o := range opts { + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + out, err := g.runBasicCommand(ctx, "workspaces/read-file", map[string]any{ + "id": opt.WorkspaceID, + "filePath": filePath, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + if err != nil { + if strings.HasSuffix(err.Error(), fmt.Sprintf("not found: %s/%s", opt.WorkspaceID, filePath)) { + return nil, newNotFoundInWorkspaceError(opt.WorkspaceID, filePath) + } + return nil, err + } + + return base64.StdEncoding.DecodeString(out) +} + +type ReadFileWithRevisionInWorkspaceResponse struct { + Content []byte `json:"content"` + RevisionID string `json:"revisionID"` +} + +func (g *GPTScript) ReadFileWithRevisionInWorkspace(ctx context.Context, filePath string, opts ...ReadFileInWorkspaceOptions) (*ReadFileWithRevisionInWorkspaceResponse, error) { + var opt ReadFileInWorkspaceOptions + for _, o := range opts { + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + out, err := g.runBasicCommand(ctx, "workspaces/read-file-with-revision", map[string]any{ + "id": opt.WorkspaceID, + "filePath": filePath, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + if err != nil { + if strings.HasSuffix(err.Error(), fmt.Sprintf("not found: %s/%s", opt.WorkspaceID, filePath)) { + return nil, newNotFoundInWorkspaceError(opt.WorkspaceID, filePath) + } + return nil, err + } + + var resp ReadFileWithRevisionInWorkspaceResponse + err = json.Unmarshal([]byte(out), &resp) + if err != nil { + return nil, err + } + + return &resp, nil +} + +type FileInfo struct { + WorkspaceID string + Name string + Size int64 + ModTime time.Time + MimeType string + RevisionID string +} + +type StatFileInWorkspaceOptions struct { + WorkspaceID string + WithLatestRevisionID bool +} + +func (g *GPTScript) StatFileInWorkspace(ctx context.Context, filePath string, opts ...StatFileInWorkspaceOptions) (FileInfo, error) { + var opt StatFileInWorkspaceOptions + for _, o := range opts { + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + opt.WithLatestRevisionID = opt.WithLatestRevisionID || o.WithLatestRevisionID + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + out, err := g.runBasicCommand(ctx, "workspaces/stat-file", map[string]any{ + "id": opt.WorkspaceID, + "filePath": filePath, + "withLatestRevisionID": opt.WithLatestRevisionID, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + if err != nil { + if strings.HasSuffix(err.Error(), fmt.Sprintf("not found: %s/%s", opt.WorkspaceID, filePath)) { + return FileInfo{}, newNotFoundInWorkspaceError(opt.WorkspaceID, filePath) + } + return FileInfo{}, err + } + + var info FileInfo + err = json.Unmarshal([]byte(out), &info) + if err != nil { + return FileInfo{}, err + } + + return info, nil +} + +type ListRevisionsForFileInWorkspaceOptions struct { + WorkspaceID string +} + +func (g *GPTScript) ListRevisionsForFileInWorkspace(ctx context.Context, filePath string, opts ...ListRevisionsForFileInWorkspaceOptions) ([]FileInfo, error) { + var opt ListRevisionsForFileInWorkspaceOptions + for _, o := range opts { + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + out, err := g.runBasicCommand(ctx, "workspaces/list-revisions", map[string]any{ + "id": opt.WorkspaceID, + "filePath": filePath, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + if err != nil { + if strings.HasSuffix(err.Error(), fmt.Sprintf("not found: %s/%s", opt.WorkspaceID, filePath)) { + return nil, newNotFoundInWorkspaceError(opt.WorkspaceID, filePath) + } + return nil, err + } + + var info []FileInfo + err = json.Unmarshal([]byte(out), &info) + if err != nil { + return nil, err + } + + return info, nil +} + +type GetRevisionForFileInWorkspaceOptions struct { + WorkspaceID string +} + +func (g *GPTScript) GetRevisionForFileInWorkspace(ctx context.Context, filePath, revisionID string, opts ...GetRevisionForFileInWorkspaceOptions) ([]byte, error) { + var opt GetRevisionForFileInWorkspaceOptions + for _, o := range opts { + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + out, err := g.runBasicCommand(ctx, "workspaces/get-revision", map[string]any{ + "id": opt.WorkspaceID, + "filePath": filePath, + "revisionID": revisionID, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + if err != nil { + if strings.HasSuffix(err.Error(), fmt.Sprintf("not found: %s/%s", opt.WorkspaceID, filePath)) { + return nil, newNotFoundInWorkspaceError(opt.WorkspaceID, filePath) + } + return nil, err + } + + return base64.StdEncoding.DecodeString(out) +} + +type DeleteRevisionForFileInWorkspaceOptions struct { + WorkspaceID string +} + +func (g *GPTScript) DeleteRevisionForFileInWorkspace(ctx context.Context, filePath, revisionID string, opts ...DeleteRevisionForFileInWorkspaceOptions) error { + var opt DeleteRevisionForFileInWorkspaceOptions + for _, o := range opts { + if o.WorkspaceID != "" { + opt.WorkspaceID = o.WorkspaceID + } + } + + if opt.WorkspaceID == "" { + opt.WorkspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID") + } + + _, err := g.runBasicCommand(ctx, "workspaces/delete-revision", map[string]any{ + "id": opt.WorkspaceID, + "filePath": filePath, + "revisionID": revisionID, + "workspaceTool": g.globalOpts.WorkspaceTool, + "env": g.globalOpts.Env, + }) + if err != nil && strings.HasSuffix(err.Error(), fmt.Sprintf("not found: %s/%s", opt.WorkspaceID, filePath)) { + return newNotFoundInWorkspaceError(opt.WorkspaceID, fmt.Sprintf("revision %s for %s", revisionID, filePath)) + } + + return err +} diff --git a/workspace_test.go b/workspace_test.go new file mode 100644 index 0000000..eb895d3 --- /dev/null +++ b/workspace_test.go @@ -0,0 +1,1110 @@ +package gptscript + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "testing" +) + +func TestWorkspaceIDRequiredForDelete(t *testing.T) { + if err := g.DeleteWorkspace(context.Background(), ""); err == nil { + t.Error("Expected error but got nil") + } +} + +func TestCreateAndDeleteWorkspace(t *testing.T) { + id, err := g.CreateWorkspace(context.Background(), "directory") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + err = g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } +} + +func TestCreateAndDeleteWorkspaceFromWorkspace(t *testing.T) { + id, err := g.CreateWorkspace(context.Background(), "directory") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err = g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "file.txt", []byte("hello world"), WriteFileInWorkspaceOptions{ + WorkspaceID: id, + }) + if err != nil { + t.Errorf("Error creating file: %v", err) + } + + newID, err := g.CreateWorkspace(context.Background(), "directory", id) + if err != nil { + t.Errorf("Error creating workspace from workspace: %v", err) + } + + content, err := g.ReadFileInWorkspace(context.Background(), "file.txt", ReadFileInWorkspaceOptions{ + WorkspaceID: newID, + }) + if err != nil { + t.Fatalf("Error reading file: %v", err) + } + + if !bytes.Equal(content, []byte("hello world")) { + t.Errorf("Unexpected content: %s", content) + } + + err = g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } +} + +func TestWriteReadAndDeleteFileFromWorkspace(t *testing.T) { + id, err := g.CreateWorkspace(context.Background(), "directory") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + content, err := g.ReadFileInWorkspace(context.Background(), "test.txt", ReadFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if !bytes.Equal(content, []byte("test")) { + t.Errorf("Unexpected content: %s", content) + } + + // Read the file and request the revision ID + contentWithRevision, err := g.ReadFileWithRevisionInWorkspace(context.Background(), "test.txt", ReadFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if !bytes.Equal(contentWithRevision.Content, []byte("test")) { + t.Errorf("Unexpected content: %s", contentWithRevision.Content) + } + + if contentWithRevision.RevisionID == "" { + t.Errorf("Expected file revision ID when requesting it: %s", contentWithRevision.RevisionID) + } + + // Stat the file to ensure it exists + fileInfo, err := g.StatFileInWorkspace(context.Background(), "test.txt", StatFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error statting file: %v", err) + } + + if fileInfo.WorkspaceID != id { + t.Errorf("Unexpected file workspace ID: %v", fileInfo.WorkspaceID) + } + + if fileInfo.Name != "test.txt" { + t.Errorf("Unexpected file name: %s", fileInfo.Name) + } + + if fileInfo.Size != 4 { + t.Errorf("Unexpected file size: %d", fileInfo.Size) + } + + if fileInfo.ModTime.IsZero() { + t.Errorf("Unexpected file mod time: %v", fileInfo.ModTime) + } + + if fileInfo.MimeType != "text/plain" { + t.Errorf("Unexpected file mime type: %s", fileInfo.MimeType) + } + + if fileInfo.RevisionID != "" { + t.Errorf("Unexpected file revision ID when not requesting it: %s", fileInfo.RevisionID) + } + + // Stat file and request the revision ID + fileInfo, err = g.StatFileInWorkspace(context.Background(), "test.txt", StatFileInWorkspaceOptions{WorkspaceID: id, WithLatestRevisionID: true}) + if err != nil { + t.Errorf("Error statting file: %v", err) + } + + if fileInfo.WorkspaceID != id { + t.Errorf("Unexpected file workspace ID: %v", fileInfo.WorkspaceID) + } + + if fileInfo.RevisionID == "" { + t.Errorf("Expected file revision ID when requesting it: %s", fileInfo.RevisionID) + } + + // Ensure we get the error we expect when trying to read a non-existent file + _, err = g.ReadFileInWorkspace(context.Background(), "test1.txt", ReadFileInWorkspaceOptions{WorkspaceID: id}) + if nf := (*NotFoundInWorkspaceError)(nil); !errors.As(err, &nf) { + t.Errorf("Unexpected error reading non-existent file: %v", err) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } +} + +func TestRevisionsForFileInWorkspace(t *testing.T) { + id, err := g.CreateWorkspace(context.Background(), "directory") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test1"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test2"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err := g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 2 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + for i, rev := range revisions { + if rev.WorkspaceID != id { + t.Errorf("Unexpected file workspace ID: %v", rev.WorkspaceID) + } + + if rev.Name != "test.txt" { + t.Errorf("Unexpected file name: %s", rev.Name) + } + + if rev.Size != 5 { + t.Errorf("Unexpected file size: %d", rev.Size) + } + + if rev.ModTime.IsZero() { + t.Errorf("Unexpected file mod time: %v", rev.ModTime) + } + + if rev.MimeType != "text/plain" { + t.Errorf("Unexpected file mime type: %s", rev.MimeType) + } + + if rev.RevisionID != fmt.Sprintf("%d", i+1) { + t.Errorf("Unexpected revision ID: %s", rev.RevisionID) + } + } + + err = g.DeleteRevisionForFileInWorkspace(context.Background(), "test.txt", "1", DeleteRevisionForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting revision for file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 0 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } +} + +func TestDisableCreateRevisionsForFileInWorkspace(t *testing.T) { + id, err := g.CreateWorkspace(context.Background(), "directory") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test1"), WriteFileInWorkspaceOptions{WorkspaceID: id, CreateRevision: new(bool)}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test2"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err := g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + for i, rev := range revisions { + if rev.WorkspaceID != id { + t.Errorf("Unexpected file workspace ID: %v", rev.WorkspaceID) + } + + if rev.Name != "test.txt" { + t.Errorf("Unexpected file name: %s", rev.Name) + } + + if rev.Size != 5 { + t.Errorf("Unexpected file size: %d", rev.Size) + } + + if rev.ModTime.IsZero() { + t.Errorf("Unexpected file mod time: %v", rev.ModTime) + } + + if rev.MimeType != "text/plain" { + t.Errorf("Unexpected file mime type: %s", rev.MimeType) + } + + if rev.RevisionID != fmt.Sprintf("%d", i+1) { + t.Errorf("Unexpected revision ID: %s", rev.RevisionID) + } + } + + err = g.DeleteRevisionForFileInWorkspace(context.Background(), "test.txt", "1", DeleteRevisionForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting revision for file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 0 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } +} + +func TestConflictsForFileInWorkspace(t *testing.T) { + id, err := g.CreateWorkspace(context.Background(), "directory") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + ce := (*ConflictInWorkspaceError)(nil) + // Writing a new file with a non-zero latest revision should fail + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: "1"}) + if err == nil || !errors.As(err, &ce) { + t.Errorf("Expected error writing file with non-zero latest revision: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test1"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err := g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Writing to the file with the latest revision should succeed + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test2"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: revisions[0].RevisionID}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 2 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Writing to the file with the same revision should fail + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test3"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: revisions[0].RevisionID}) + if err == nil || !errors.As(err, &ce) { + t.Errorf("Expected error writing file with same revision: %v", err) + } + + latestRevisionID := revisions[1].RevisionID + err = g.DeleteRevisionForFileInWorkspace(context.Background(), "test.txt", latestRevisionID, DeleteRevisionForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting revision for file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Ensure we cannot write a new file with the zero-th revision ID + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test4"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: revisions[0].RevisionID}) + if err == nil || !errors.As(err, &ce) { + t.Errorf("Unexpected error writing to file: %v", err) + } + + // Ensure we can write a new file after deleting the latest revision + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test4"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: latestRevisionID}) + if err != nil { + t.Errorf("Error writing file: %v", err) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } +} + +func TestLsComplexWorkspace(t *testing.T) { + id, err := g.CreateWorkspace(context.Background(), "directory") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "test/test1.txt", []byte("hello1"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test1/test2.txt", []byte("hello2"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test1/test3.txt", []byte("hello3"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), ".hidden.txt", []byte("hidden"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating hidden file: %v", err) + } + + // List all files + content, err := g.ListFilesInWorkspace(context.Background(), ListFilesInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error listing files: %v", err) + } + + if len(content) != 4 { + t.Errorf("Unexpected number of files: %d", len(content)) + } + + // List files in subdirectory + content, err = g.ListFilesInWorkspace(context.Background(), ListFilesInWorkspaceOptions{WorkspaceID: id, Prefix: "test1"}) + if err != nil { + t.Fatalf("Error listing files: %v", err) + } + + if len(content) != 2 { + t.Errorf("Unexpected number of files: %d", len(content)) + } + + // Remove all files with test1 prefix + err = g.RemoveAll(context.Background(), RemoveAllOptions{WorkspaceID: id, WithPrefix: "test1"}) + if err != nil { + t.Fatalf("Error removing files: %v", err) + } + + // List files in subdirectory + content, err = g.ListFilesInWorkspace(context.Background(), ListFilesInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error listing files: %v", err) + } + + if len(content) != 2 { + t.Errorf("Unexpected number of files: %d", len(content)) + } +} + +func TestCreateAndDeleteWorkspaceS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + err = g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } +} + +func TestCreateAndDeleteWorkspaceFromWorkspaceS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "file.txt", []byte("hello world"), WriteFileInWorkspaceOptions{ + WorkspaceID: id, + }) + if err != nil { + t.Errorf("Error creating file: %v", err) + } + + newID, err := g.CreateWorkspace(context.Background(), "s3", id) + if err != nil { + t.Errorf("Error creating workspace from workspace: %v", err) + } + + content, err := g.ReadFileInWorkspace(context.Background(), "file.txt", ReadFileInWorkspaceOptions{ + WorkspaceID: newID, + }) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if !bytes.Equal(content, []byte("hello world")) { + t.Errorf("Unexpected content: %s", content) + } + + err = g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + + err = g.DeleteWorkspace(context.Background(), newID) + if err != nil { + t.Errorf("Error deleting new workspace: %v", err) + } +} + +func TestCreateAndDeleteDirectoryWorkspaceFromWorkspaceS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "file.txt", []byte("hello world"), WriteFileInWorkspaceOptions{ + WorkspaceID: id, + }) + if err != nil { + t.Errorf("Error creating file: %v", err) + } + + newID, err := g.CreateWorkspace(context.Background(), "directory", id) + if err != nil { + t.Errorf("Error creating workspace from workspace: %v", err) + } + + content, err := g.ReadFileInWorkspace(context.Background(), "file.txt", ReadFileInWorkspaceOptions{ + WorkspaceID: newID, + }) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if !bytes.Equal(content, []byte("hello world")) { + t.Errorf("Unexpected content: %s", content) + } + + err = g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + + err = g.DeleteWorkspace(context.Background(), newID) + if err != nil { + t.Errorf("Error deleting new workspace: %v", err) + } +} + +func TestCreateAndDeleteS3WorkspaceFromWorkspaceDirectory(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err = g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "file.txt", []byte("hello world"), WriteFileInWorkspaceOptions{ + WorkspaceID: id, + }) + if err != nil { + t.Errorf("Error creating file: %v", err) + } + + newID, err := g.CreateWorkspace(context.Background(), "directory", id) + if err != nil { + t.Errorf("Error creating workspace from workspace: %v", err) + } + + content, err := g.ReadFileInWorkspace(context.Background(), "file.txt", ReadFileInWorkspaceOptions{ + WorkspaceID: newID, + }) + if err != nil { + t.Fatalf("Error reading file: %v", err) + } + + if !bytes.Equal(content, []byte("hello world")) { + t.Errorf("Unexpected content: %s", content) + } + + err = g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } +} + +func TestWriteReadAndDeleteFileFromWorkspaceS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + content, err := g.ReadFileInWorkspace(context.Background(), "test.txt", ReadFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if !bytes.Equal(content, []byte("test")) { + t.Errorf("Unexpected content: %s", content) + } + + // Read the file and request the revision ID + contentWithRevision, err := g.ReadFileWithRevisionInWorkspace(context.Background(), "test.txt", ReadFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if !bytes.Equal(contentWithRevision.Content, []byte("test")) { + t.Errorf("Unexpected content: %s", contentWithRevision.Content) + } + + if contentWithRevision.RevisionID == "" { + t.Errorf("Expected file revision ID when requesting it: %s", contentWithRevision.RevisionID) + } + + // Stat the file to ensure it exists + fileInfo, err := g.StatFileInWorkspace(context.Background(), "test.txt", StatFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error statting file: %v", err) + } + + if fileInfo.WorkspaceID != id { + t.Errorf("Unexpected file workspace ID: %v", fileInfo.WorkspaceID) + } + + if fileInfo.Name != "test.txt" { + t.Errorf("Unexpected file name: %s", fileInfo.Name) + } + + if fileInfo.Size != 4 { + t.Errorf("Unexpected file size: %d", fileInfo.Size) + } + + if fileInfo.ModTime.IsZero() { + t.Errorf("Unexpected file mod time: %v", fileInfo.ModTime) + } + + if fileInfo.MimeType != "text/plain" { + t.Errorf("Unexpected file mime type: %s", fileInfo.MimeType) + } + + if fileInfo.RevisionID != "" { + t.Errorf("Unexpected file revision ID when not requesting it: %s", fileInfo.RevisionID) + } + + // Stat file and request the revision ID + fileInfo, err = g.StatFileInWorkspace(context.Background(), "test.txt", StatFileInWorkspaceOptions{WorkspaceID: id, WithLatestRevisionID: true}) + if err != nil { + t.Errorf("Error statting file: %v", err) + } + + if fileInfo.WorkspaceID != id { + t.Errorf("Unexpected file workspace ID: %v", fileInfo.WorkspaceID) + } + + if fileInfo.RevisionID == "" { + t.Errorf("Expected file revision ID when requesting it: %s", fileInfo.RevisionID) + } + + // Ensure we get the error we expect when trying to read a non-existent file + _, err = g.ReadFileInWorkspace(context.Background(), "test1.txt", ReadFileInWorkspaceOptions{WorkspaceID: id}) + if nf := (*NotFoundInWorkspaceError)(nil); !errors.As(err, &nf) { + t.Errorf("Unexpected error reading non-existent file: %v", err) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } +} + +func TestRevisionsForFileInWorkspaceS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test1"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test2"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err := g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 2 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + for i, rev := range revisions { + if rev.WorkspaceID != id { + t.Errorf("Unexpected file workspace ID: %v", rev.WorkspaceID) + } + + if rev.Name != "test.txt" { + t.Errorf("Unexpected file name: %s", rev.Name) + } + + if rev.Size != 5 { + t.Errorf("Unexpected file size: %d", rev.Size) + } + + if rev.ModTime.IsZero() { + t.Errorf("Unexpected file mod time: %v", rev.ModTime) + } + + if rev.MimeType != "text/plain" { + t.Errorf("Unexpected file mime type: %s", rev.MimeType) + } + + if rev.RevisionID != fmt.Sprintf("%d", i+1) { + t.Errorf("Unexpected revision ID: %s", rev.RevisionID) + } + } + + err = g.DeleteRevisionForFileInWorkspace(context.Background(), "test.txt", "1", DeleteRevisionForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting revision for file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 0 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } +} + +func TestConflictsForFileInWorkspaceS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + ce := (*ConflictInWorkspaceError)(nil) + // Writing a new file with a non-zero latest revision should fail + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: "1"}) + if err == nil || !errors.As(err, &ce) { + t.Errorf("Expected error writing file with non-zero latest revision: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test1"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err := g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Writing to the file with the latest revision should succeed + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test2"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: revisions[0].RevisionID}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 2 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Writing to the file with the same revision should fail + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test3"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: revisions[0].RevisionID}) + if err == nil || !errors.As(err, &ce) { + t.Errorf("Expected error writing file with same revision: %v", err) + } + + latestRevisionID := revisions[1].RevisionID + err = g.DeleteRevisionForFileInWorkspace(context.Background(), "test.txt", latestRevisionID, DeleteRevisionForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting revision for file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + // Ensure we cannot write a new file with the zero-th revision ID + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test4"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: revisions[0].RevisionID}) + if err == nil || !errors.As(err, &ce) { + t.Fatalf("Error creating file: %v", err) + } + + // Ensure we can write a new file after deleting the latest revision + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test4"), WriteFileInWorkspaceOptions{WorkspaceID: id, LatestRevisionID: latestRevisionID}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } +} + +func TestDisableCreatingRevisionsForFileInWorkspaceS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test0"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test1"), WriteFileInWorkspaceOptions{WorkspaceID: id, CreateRevision: new(bool)}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test.txt", []byte("test2"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + revisions, err := g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 1 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + for i, rev := range revisions { + if rev.WorkspaceID != id { + t.Errorf("Unexpected file workspace ID: %v", rev.WorkspaceID) + } + + if rev.Name != "test.txt" { + t.Errorf("Unexpected file name: %s", rev.Name) + } + + if rev.Size != 5 { + t.Errorf("Unexpected file size: %d", rev.Size) + } + + if rev.ModTime.IsZero() { + t.Errorf("Unexpected file mod time: %v", rev.ModTime) + } + + if rev.MimeType != "text/plain" { + t.Errorf("Unexpected file mime type: %s", rev.MimeType) + } + + if rev.RevisionID != fmt.Sprintf("%d", i+1) { + t.Errorf("Unexpected revision ID: %s", rev.RevisionID) + } + } + + err = g.DeleteRevisionForFileInWorkspace(context.Background(), "test.txt", "1", DeleteRevisionForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting revision for file: %v", err) + } + + revisions, err = g.ListRevisionsForFileInWorkspace(context.Background(), "test.txt", ListRevisionsForFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error reading file: %v", err) + } + + if len(revisions) != 0 { + t.Errorf("Unexpected number of revisions: %d", len(revisions)) + } + + err = g.DeleteFileInWorkspace(context.Background(), "test.txt", DeleteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Errorf("Error deleting file: %v", err) + } +} + +func TestLsComplexWorkspaceS3(t *testing.T) { + if os.Getenv("AWS_ACCESS_KEY_ID") == "" || os.Getenv("AWS_SECRET_ACCESS_KEY") == "" || os.Getenv("WORKSPACE_PROVIDER_S3_BUCKET") == "" { + t.Skip("Skipping test because AWS credentials are not set") + } + + id, err := g.CreateWorkspace(context.Background(), "s3") + if err != nil { + t.Fatalf("Error creating workspace: %v", err) + } + + t.Cleanup(func() { + err := g.DeleteWorkspace(context.Background(), id) + if err != nil { + t.Errorf("Error deleting workspace: %v", err) + } + }) + + err = g.WriteFileInWorkspace(context.Background(), "test/test1.txt", []byte("hello1"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test1/test2.txt", []byte("hello2"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), "test1/test3.txt", []byte("hello3"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating file: %v", err) + } + + err = g.WriteFileInWorkspace(context.Background(), ".hidden.txt", []byte("hidden"), WriteFileInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error creating hidden file: %v", err) + } + + // List all files + content, err := g.ListFilesInWorkspace(context.Background(), ListFilesInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error listing files: %v", err) + } + + if len(content) != 4 { + t.Errorf("Unexpected number of files: %d", len(content)) + } + + // List files in subdirectory + content, err = g.ListFilesInWorkspace(context.Background(), ListFilesInWorkspaceOptions{WorkspaceID: id, Prefix: "test1"}) + if err != nil { + t.Fatalf("Error listing files: %v", err) + } + + if len(content) != 2 { + t.Errorf("Unexpected number of files: %d", len(content)) + } + + // Remove all files with test1 prefix + err = g.RemoveAll(context.Background(), RemoveAllOptions{WorkspaceID: id, WithPrefix: "test1"}) + if err != nil { + t.Fatalf("Error removing files: %v", err) + } + + // List files in subdirectory + content, err = g.ListFilesInWorkspace(context.Background(), ListFilesInWorkspaceOptions{WorkspaceID: id}) + if err != nil { + t.Fatalf("Error listing files: %v", err) + } + + if len(content) != 2 { + t.Errorf("Unexpected number of files: %d", len(content)) + } +}