diff --git a/audio.go b/audio.go index 12c6ccc22..bf2365391 100644 --- a/audio.go +++ b/audio.go @@ -6,6 +6,8 @@ import ( "fmt" "net/http" "os" + + utils "github.com/sashabaranov/go-openai/internal" ) // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. @@ -72,7 +74,7 @@ func (c *Client) callAudioAPI( if err != nil { return AudioResponse{}, err } - req.Header.Add("Content-Type", builder.formDataContentType()) + req.Header.Add("Content-Type", builder.FormDataContentType()) if request.HasJSONResponse() { err = c.sendRequest(req, &response) @@ -92,26 +94,26 @@ func (r AudioRequest) HasJSONResponse() bool { // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. -func audioMultipartForm(request AudioRequest, b formBuilder) error { +func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { f, err := os.Open(request.FilePath) if err != nil { return fmt.Errorf("opening audio file: %w", err) } defer f.Close() - err = b.createFormFile("file", f) + err = b.CreateFormFile("file", f) if err != nil { return fmt.Errorf("creating form file: %w", err) } - err = b.writeField("model", request.Model) + err = b.WriteField("model", request.Model) if err != nil { return fmt.Errorf("writing model name: %w", err) } // Create a form field for the prompt (if provided) if request.Prompt != "" { - err = b.writeField("prompt", request.Prompt) + err = b.WriteField("prompt", request.Prompt) if err != nil { return fmt.Errorf("writing prompt: %w", err) } @@ -119,7 +121,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { // Create a form field for the format (if provided) if request.Format != "" { - err = b.writeField("response_format", string(request.Format)) + err = b.WriteField("response_format", string(request.Format)) if err != nil { return fmt.Errorf("writing format: %w", err) } @@ -127,7 +129,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { // Create a form field for the temperature (if provided) if request.Temperature != 0 { - err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature)) + err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature)) if err != nil { return fmt.Errorf("writing temperature: %w", err) } @@ -135,12 +137,12 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error { // Create a form field for the language (if provided) if request.Language != "" { - err = b.writeField("language", request.Language) + err = b.WriteField("language", request.Language) if err != nil { return fmt.Errorf("writing language: %w", err) } } // Close the multipart writer - return b.close() + return b.Close() } diff --git a/client.go b/client.go index 9579ba27b..c55166aa6 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,8 @@ import ( "io" "net/http" "strings" + + utils "github.com/sashabaranov/go-openai/internal" ) // Client is OpenAI GPT-3 API client. @@ -14,7 +16,7 @@ type Client struct { config ClientConfig requestBuilder requestBuilder - createFormBuilder func(io.Writer) formBuilder + createFormBuilder func(io.Writer) utils.FormBuilder } // NewClient creates new OpenAI API client. @@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, requestBuilder: newRequestBuilder(), - createFormBuilder: func(body io.Writer) formBuilder { - return newFormBuilder(body) + createFormBuilder: func(body io.Writer) utils.FormBuilder { + return utils.NewFormBuilder(body) }, } } diff --git a/files.go b/files.go index b701b9454..5667ec861 100644 --- a/files.go +++ b/files.go @@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File var b bytes.Buffer builder := c.createFormBuilder(&b) - err = builder.writeField("purpose", request.Purpose) + err = builder.WriteField("purpose", request.Purpose) if err != nil { return } @@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - err = builder.createFormFile("file", fileData) + err = builder.CreateFormFile("file", fileData) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } @@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - req.Header.Set("Content-Type", builder.formDataContentType()) + req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &file) diff --git a/files_test.go b/files_test.go index bb06498c8..56dbb414f 100644 --- a/files_test.go +++ b/files_test.go @@ -1,6 +1,7 @@ package openai //nolint:testpackage // testing private field import ( + . "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { config.BaseURL = "" client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) FormBuilder { return mockBuilder } diff --git a/form_builder.go b/form_builder.go deleted file mode 100644 index 7fbb1643a..000000000 --- a/form_builder.go +++ /dev/null @@ -1,49 +0,0 @@ -package openai - -import ( - "io" - "mime/multipart" - "os" -) - -type formBuilder interface { - createFormFile(fieldname string, file *os.File) error - writeField(fieldname, value string) error - close() error - formDataContentType() string -} - -type defaultFormBuilder struct { - writer *multipart.Writer -} - -func newFormBuilder(body io.Writer) *defaultFormBuilder { - return &defaultFormBuilder{ - writer: multipart.NewWriter(body), - } -} - -func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error { - fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) - if err != nil { - return err - } - - _, err = io.Copy(fieldWriter, file) - if err != nil { - return err - } - return nil -} - -func (fb *defaultFormBuilder) writeField(fieldname, value string) error { - return fb.writer.WriteField(fieldname, value) -} - -func (fb *defaultFormBuilder) close() error { - return fb.writer.Close() -} - -func (fb *defaultFormBuilder) formDataContentType() string { - return fb.writer.FormDataContentType() -} diff --git a/image.go b/image.go index 21703bda7..87ffea25e 100644 --- a/image.go +++ b/image.go @@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) builder := c.createFormBuilder(body) // image - err = builder.createFormFile("image", request.Image) + err = builder.CreateFormFile("image", request.Image) if err != nil { return } // mask, it is optional if request.Mask != nil { - err = builder.createFormFile("mask", request.Mask) + err = builder.CreateFormFile("mask", request.Mask) if err != nil { return } } - err = builder.writeField("prompt", request.Prompt) + err = builder.WriteField("prompt", request.Prompt) if err != nil { return } - err = builder.writeField("n", strconv.Itoa(request.N)) + err = builder.WriteField("n", strconv.Itoa(request.N)) if err != nil { return } - err = builder.writeField("size", request.Size) + err = builder.WriteField("size", request.Size) if err != nil { return } - err = builder.writeField("response_format", request.ResponseFormat) + err = builder.WriteField("response_format", request.ResponseFormat) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } @@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req.Header.Set("Content-Type", builder.formDataContentType()) + req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } @@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) builder := c.createFormBuilder(body) // image - err = builder.createFormFile("image", request.Image) + err = builder.CreateFormFile("image", request.Image) if err != nil { return } - err = builder.writeField("n", strconv.Itoa(request.N)) + err = builder.WriteField("n", strconv.Itoa(request.N)) if err != nil { return } - err = builder.writeField("size", request.Size) + err = builder.WriteField("size", request.Size) if err != nil { return } - err = builder.writeField("response_format", request.ResponseFormat) + err = builder.WriteField("response_format", request.ResponseFormat) if err != nil { return } - err = builder.close() + err = builder.Close() if err != nil { return } @@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req.Header.Set("Content-Type", builder.formDataContentType()) + req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } diff --git a/image_test.go b/image_test.go index 4a7dad58f..5cf6a268d 100644 --- a/image_test.go +++ b/image_test.go @@ -1,6 +1,7 @@ package openai //nolint:testpackage // testing private field import ( + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -268,19 +269,19 @@ type mockFormBuilder struct { mockClose func() error } -func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error { +func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { return fb.mockCreateFormFile(fieldname, file) } -func (fb *mockFormBuilder) writeField(fieldname, value string) error { +func (fb *mockFormBuilder) WriteField(fieldname, value string) error { return fb.mockWriteField(fieldname, value) } -func (fb *mockFormBuilder) close() error { +func (fb *mockFormBuilder) Close() error { return fb.mockClose() } -func (fb *mockFormBuilder) formDataContentType() string { +func (fb *mockFormBuilder) FormDataContentType() string { return "" } @@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) { client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) utils.FormBuilder { return mockBuilder } ctx := context.Background() @@ -357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) { client := NewClientWithConfig(config) mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) formBuilder { + client.createFormBuilder = func(io.Writer) utils.FormBuilder { return mockBuilder } ctx := context.Background() diff --git a/internal/form_builder.go b/internal/form_builder.go new file mode 100644 index 000000000..359dd7e2a --- /dev/null +++ b/internal/form_builder.go @@ -0,0 +1,49 @@ +package openai + +import ( + "io" + "mime/multipart" + "os" +) + +type FormBuilder interface { + CreateFormFile(fieldname string, file *os.File) error + WriteField(fieldname, value string) error + Close() error + FormDataContentType() string +} + +type DefaultFormBuilder struct { + writer *multipart.Writer +} + +func NewFormBuilder(body io.Writer) *DefaultFormBuilder { + return &DefaultFormBuilder{ + writer: multipart.NewWriter(body), + } +} + +func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { + fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, file) + if err != nil { + return err + } + return nil +} + +func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { + return fb.writer.WriteField(fieldname, value) +} + +func (fb *DefaultFormBuilder) Close() error { + return fb.writer.Close() +} + +func (fb *DefaultFormBuilder) FormDataContentType() string { + return fb.writer.FormDataContentType() +} diff --git a/form_builder_test.go b/internal/form_builder_test.go similarity index 88% rename from form_builder_test.go rename to internal/form_builder_test.go index 78e2ec968..d3faf9982 100644 --- a/form_builder_test.go +++ b/internal/form_builder_test.go @@ -30,8 +30,8 @@ func TestFormBuilderWithFailingWriter(t *testing.T) { defer file.Close() defer os.Remove(file.Name()) - builder := newFormBuilder(&failingWriter{}) - err = builder.createFormFile("file", file) + builder := NewFormBuilder(&failingWriter{}) + err = builder.CreateFormFile("file", file) checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") } @@ -47,8 +47,8 @@ func TestFormBuilderWithClosedFile(t *testing.T) { defer os.Remove(file.Name()) body := &bytes.Buffer{} - builder := newFormBuilder(body) - err = builder.createFormFile("file", file) + builder := NewFormBuilder(body) + err = builder.CreateFormFile("file", file) checks.HasError(t, err, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") }