forked from sashabaranov/go-openai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclient.go
138 lines (118 loc) · 3.79 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package openai
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// Client is a client for the OpenAI API. The same type also talks to Azure
// endpoints when its configuration selects an Azure API type.
//
// Construct it with NewClient or NewClientWithConfig.
type Client struct {
	// config holds the settings the request helpers read: credentials, API
	// type, base URL, organization ID and the HTTP client used for calls.
	config ClientConfig

	// requestBuilder creates the *http.Request values sent by this client.
	requestBuilder requestBuilder

	// createFormBuilder is a factory that returns a formBuilder for the given
	// writer (presumably the destination of the encoded form; confirm against
	// the formBuilder implementation).
	createFormBuilder func(io.Writer) formBuilder
}
// NewClient returns an OpenAI API client that authenticates with authToken and
// otherwise uses the configuration produced by DefaultConfig.
func NewClient(authToken string) *Client {
	return NewClientWithConfig(DefaultConfig(authToken))
}
// NewClientWithConfig returns an OpenAI API client that uses the supplied
// config, wired with the default request builder and form builder.
func NewClientWithConfig(config ClientConfig) *Client {
	client := new(Client)
	client.config = config
	client.requestBuilder = newRequestBuilder()
	// Wrapped in a closure so the field's declared signature is satisfied
	// regardless of the concrete type newFormBuilder returns.
	client.createFormBuilder = func(w io.Writer) formBuilder {
		return newFormBuilder(w)
	}
	return client
}
// NewOrgClient returns an OpenAI API client that authenticates with authToken
// and is bound to the organization identified by org. All other settings come
// from DefaultConfig.
//
// Deprecated: Please use NewClientWithConfig.
func NewOrgClient(authToken, org string) *Client {
	cfg := DefaultConfig(authToken)
	cfg.OrgID = org

	return NewClientWithConfig(cfg)
}
// sendRequest adds the client's standard headers to req, executes it with the
// configured HTTP client and, on success, decodes the JSON response body
// into v.
//
// Headers applied: Accept, the credential header (the Azure API key header for
// APITypeAzure, a Bearer Authorization header otherwise), a JSON Content-Type
// when the caller has not already chosen one, and OpenAI-Organization when an
// organization ID is configured.
//
// It returns the error from the HTTP client, the error produced by
// handleErrorResp for any status code outside [200, 400), or the JSON decoding
// error. When v is nil the response body is not decoded.
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
	header := req.Header

	// Credentials: Azure API key authentication uses a dedicated header;
	// OpenAI and Azure AD authentication both use a bearer token.
	switch c.config.APIType {
	case APITypeAzure:
		header.Set(AzureAPIKeyHeader, c.config.authToken)
	default:
		header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
	}

	header.Set("Accept", "application/json; charset=utf-8")

	// Respect a Content-Type chosen by the caller: the Upload Files API needs
	// multipart/form-data rather than JSON.
	if header.Get("Content-Type") == "" {
		header.Set("Content-Type", "application/json; charset=utf-8")
	}

	if c.config.OrgID != "" {
		header.Set("OpenAI-Organization", c.config.OrgID)
	}

	resp, err := c.config.HTTPClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	succeeded := resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest
	if !succeeded {
		return c.handleErrorResp(resp)
	}

	if v == nil {
		return nil
	}
	return json.NewDecoder(resp.Body).Decode(v)
}
// fullURL returns the absolute endpoint URL for the given API path suffix.
//
// For the Azure and Azure AD API types the result has the form
//
//	{base}/{azureAPIPrefix}/{azureDeploymentsPrefix}/{engine}{suffix}?api-version={version}
//
// with trailing slashes stripped from the configured base URL first. For every
// other API type the suffix is appended to the configured base URL unchanged.
func (c *Client) fullURL(suffix string) string {
	usesAzure := c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD
	if !usesAzure {
		// OpenAI (or unset) API type: no deployment path and no api-version query.
		return c.config.BaseURL + suffix
	}

	// e.g. /openai/deployments/{engine}/chat/completions?api-version={api_version}
	root := strings.TrimRight(c.config.BaseURL, "/")
	return fmt.Sprintf(
		"%s/%s/%s/%s%s?api-version=%s",
		root, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion,
	)
}
// newStreamRequest builds an HTTP request for an event-stream endpoint.
//
// The request is created by the client's requestBuilder against
// fullURL(urlSuffix) with the given method and body, and is then given the
// headers a streaming call needs: a JSON Content-Type, Accept:
// text/event-stream, Cache-Control: no-cache and Connection: keep-alive,
// followed by the credential header and, when an organization ID is
// configured, OpenAI-Organization.
//
// If the request cannot be built, the builder's error is returned unchanged
// together with a nil request.
func (c *Client) newStreamRequest(
	ctx context.Context,
	method string,
	urlSuffix string,
	body any) (*http.Request, error) {
	req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	req.Header.Set("Cache-Control", "no-cache")
	req.Header.Set("Connection", "keep-alive")

	// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
	// Azure API Key authentication
	if c.config.APIType == APITypeAzure {
		req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
	} else {
		// OpenAI or Azure AD authentication
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
	}

	// Streaming calls must carry the configured organization just like
	// non-streaming ones; without this header the organization setting is
	// silently ignored for streamed requests.
	if c.config.OrgID != "" {
		req.Header.Set("OpenAI-Organization", c.config.OrgID)
	}

	return req, nil
}
// handleErrorResp converts a non-success HTTP response into an error.
//
// The response body is decoded as an ErrorResponse. When decoding succeeds and
// the payload contains an error object, that object is stamped with the HTTP
// status code and wrapped in the returned error. Otherwise a *RequestError
// carrying the status code and the decoding error (nil if the body decoded but
// held no error object) is wrapped instead.
func (c *Client) handleErrorResp(resp *http.Response) error {
	status := resp.StatusCode

	var payload ErrorResponse
	decodeErr := json.NewDecoder(resp.Body).Decode(&payload)

	if decodeErr == nil && payload.Error != nil {
		// Structured API error: record the status on it so callers that
		// unwrap the error can inspect it.
		payload.Error.HTTPStatusCode = status
		return fmt.Errorf("error, status code: %d, message: %w", status, payload.Error)
	}

	// Body was unreadable or carried no error object: report the raw status.
	return fmt.Errorf("error, %w", &RequestError{
		HTTPStatusCode: status,
		Err:            decodeErr,
	})
}