Skip to content

Commit

Permalink
fix: fix base64 images buffer billing
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Feb 15, 2024
1 parent b95bf94 commit c5ab97a
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 15 deletions.
4 changes: 2 additions & 2 deletions adapter/azure/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ func formatMessages(props *ChatProps) interface{} {
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
obj, err := utils.NewImage(url)
if err != nil {
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url))
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), utils.Extract(url, 24, "...")))
return nil
}

props.Buffer.AddImage(obj, url)
props.Buffer.AddImage(obj)

return &MessageContent{
Type: "image_url",
Expand Down
5 changes: 2 additions & 3 deletions adapter/chatgpt/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ func formatMessages(props *ChatProps) interface{} {
content, urls := utils.ExtractImages(message.Content, true)
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
obj, err := utils.NewImage(url)
props.Buffer.AddImage(obj)
if err != nil {
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url))
} else {
props.Buffer.AddImage(obj, url)
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), utils.Extract(url, 24, "...")))
}

return &MessageContent{
Expand Down
42 changes: 35 additions & 7 deletions utils/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,37 @@ type Buffer struct {
Charge Charge `json:"-"`
}

func initInputToken(charge Charge, model string, history []globals.Message) float32 {
if globals.IsOpenAIVisionModels(model) {
for _, message := range history {
if message.Role == globals.User {
content, _ := ExtractImages(message.Content, true)
message.Content = content
}
}

history = Each(history, func(message globals.Message) globals.Message {
if message.Role == globals.User {
raw, _ := ExtractImages(message.Content, true)
return globals.Message{
Role: message.Role,
Content: raw,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
}
}

return message
})
}

return CountInputToken(charge, model, history)
}

func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
return &Buffer{
Model: model,
Quota: CountInputToken(charge, model, history),
Quota: initInputToken(charge, model, history),
History: history,
Charge: charge,
}
Expand All @@ -58,14 +85,15 @@ func (b *Buffer) GetChunk() string {
return b.Latest
}

func (b *Buffer) AddImage(image *Image, source string) {
b.Images = append(b.Images, *image)
func (b *Buffer) AddImage(image *Image) {
if image != nil {
b.Images = append(b.Images, *image)
}

if b.Charge.IsBillingType(globals.TokenBilling) {
b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput()

// remove tokens from image source
b.Quota -= CountInputToken(b.Charge, b.Model, []globals.Message{{Content: source, Role: globals.User}})
if image != nil {
b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput()
}
}
}

Expand Down
7 changes: 4 additions & 3 deletions utils/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ type Images []Image

func NewImage(url string) (*Image, error) {
if strings.HasPrefix(url, "data:image/") {
data := strings.Split(url, ",")
if len(data) != 2 {
data := SafeSplit(url, ",", 2)
if data[1] == "" {
return nil, nil
}

decoded, err := Base64Decode(data[1])
if err != nil {
return nil, err
Expand Down Expand Up @@ -78,7 +79,7 @@ func ConvertToBase64(url string) (string, error) {
}
return data[1], nil
}

res, err := http.Get(url)
if err != nil {
return "", err
Expand Down

0 comments on commit c5ab97a

Please sign in to comment.