forked from sashabaranov/go-openai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
stream_test.go
142 lines (124 loc) · 4.2 KB
/
stream_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package gogpt_test
import (
. "github.com/sashabaranov/go-gpt3"
"github.com/sashabaranov/go-gpt3/internal/test"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
)
// TestCreateCompletionStream checks that a completion stream yields each
// event served by a local test server, in order, and then reports io.EOF.
//
// The test server answers every request with a fixed text/event-stream body
// containing two "message" events followed by a "[DONE]" event.
func TestCreateCompletionStream(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")

		// Send test responses
		dataBytes := []byte{}
		dataBytes = append(dataBytes, []byte("event: message\n")...)
		//nolint:lll
		data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

		dataBytes = append(dataBytes, []byte("event: message\n")...)
		//nolint:lll
		data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}`
		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

		dataBytes = append(dataBytes, []byte("event: done\n")...)
		dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

		_, err := w.Write(dataBytes)
		if err != nil {
			// Errorf, not Fatalf: the handler does not run on the test
			// goroutine, and FailNow must only be called from that goroutine.
			t.Errorf("Write error: %s", err)
		}
	}))
	defer server.Close()

	// Client portion of the test
	config := DefaultConfig(test.GetTestToken())
	config.BaseURL = server.URL + "/v1"
	config.HTTPClient.Transport = &tokenRoundTripper{
		token:    test.GetTestToken(),
		fallback: http.DefaultTransport,
	}

	client := NewClientWithConfig(config)
	ctx := context.Background()

	request := CompletionRequest{
		Prompt:    "Ex falso quodlibet",
		Model:     "text-davinci-002",
		MaxTokens: 10,
		Stream:    true,
	}

	stream, err := client.CreateCompletionStream(ctx, request)
	if err != nil {
		// Stop here: without a usable stream, the deferred Close and the Recv
		// calls below would operate on an invalid (presumably nil) value and
		// panic instead of reporting a clean test failure.
		t.Fatalf("CreateCompletionStream returned error: %v", err)
	}
	defer stream.Close()

	expectedResponses := []CompletionResponse{
		{
			ID:      "1",
			Object:  "completion",
			Created: 1598069254,
			Model:   "text-davinci-002",
			Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}},
		},
		{
			ID:      "2",
			Object:  "completion",
			Created: 1598069255,
			Model:   "text-davinci-002",
			Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
		},
	}

	for ix, expectedResponse := range expectedResponses {
		receivedResponse, streamErr := stream.Recv()
		if streamErr != nil {
			t.Errorf("stream.Recv() failed: %v", streamErr)
			// The response accompanying an error is not meaningful; comparing
			// it would only add a second, misleading failure message.
			continue
		}
		if !compareResponses(expectedResponse, receivedResponse) {
			t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
		}
	}

	// Once every expected event has been consumed the stream must end with EOF.
	_, streamErr := stream.Recv()
	if !errors.Is(streamErr, io.EOF) {
		t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
	}
}
// tokenRoundTripper is an http.RoundTripper that sets a bearer token in the
// Authorization header of every request and then hands the request to the
// fallback transport. The test installs it as the client's transport so that
// requests reaching the test server carry the token.
type tokenRoundTripper struct {
	token    string
	fallback http.RoundTripper
}

// RoundTrip implements http.RoundTripper. It forwards a copy of req whose
// Authorization header is "Bearer " followed by the configured token, and
// returns whatever the fallback transport returns.
//
// The header is set on a clone rather than on req itself because the
// http.RoundTripper contract states that RoundTrip should not modify the
// caller's request.
func (rt *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	authed := req.Clone(req.Context())
	if authed.Header == nil {
		// A hand-built request may have no header map; Set on a nil map panics.
		authed.Header = make(http.Header)
	}
	authed.Header.Set("Authorization", "Bearer "+rt.token)
	return rt.fallback.RoundTrip(authed)
}
// Helper funcs.

// compareResponses reports whether r1 and r2 have the same ID, Object,
// Created and Model values and the same choices, compared pairwise in order
// with compareResponseChoices.
func compareResponses(r1, r2 CompletionResponse) bool {
	switch {
	case r1.ID != r2.ID, r1.Object != r2.Object, r1.Created != r2.Created, r1.Model != r2.Model:
		return false
	case len(r1.Choices) != len(r2.Choices):
		return false
	}

	for idx, choice := range r1.Choices {
		if !compareResponseChoices(choice, r2.Choices[idx]) {
			return false
		}
	}
	return true
}
// compareResponseChoices reports whether c1 and c2 carry the same Text and
// FinishReason. No other fields are compared.
func compareResponseChoices(c1, c2 CompletionChoice) bool {
	sameText := c1.Text == c2.Text
	return sameText && c1.FinishReason == c2.FinishReason
}