Skip to content

Commit

Permalink
fix: nil pointer dereference error when carrying an image to a conv…
Browse files Browse the repository at this point in the history
…ersation (coaidev#221)
  • Loading branch information
Sh1n3zZ committed Jul 1, 2024
1 parent 9cb9580 commit 576213d
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 13 deletions.
15 changes: 9 additions & 6 deletions channel/system.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,13 @@ type SearchState struct {
}

type commonState struct {
Article []string `json:"article" mapstructure:"article"`
Generation []string `json:"generation" mapstructure:"generation"`
Cache []string `json:"cache" mapstructure:"cache"`
Expire int64 `json:"expire" mapstructure:"expire"`
Size int64 `json:"size" mapstructure:"size"`
ImageStore bool `json:"image_store" mapstructure:"imagestore"`
Article []string `json:"article" mapstructure:"article"`
Generation []string `json:"generation" mapstructure:"generation"`
Cache []string `json:"cache" mapstructure:"cache"`
Expire int64 `json:"expire" mapstructure:"expire"`
Size int64 `json:"size" mapstructure:"size"`
ImageStore bool `json:"image_store" mapstructure:"imagestore"`
PromptStore bool `json:"prompt_store" mapstructure:"promptstore"`
}

type SystemConfig struct {
Expand Down Expand Up @@ -114,6 +115,8 @@ func (c *SystemConfig) Load() {
globals.CacheAcceptedSize = c.GetCacheAcceptedSize()
globals.AcceptImageStore = c.AcceptImageStore()

globals.AcceptPromptStore = c.Common.PromptStore

if c.General.PWAManifest == "" {
c.General.PWAManifest = utils.ReadPWAManifest()
}
Expand Down
6 changes: 6 additions & 0 deletions globals/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ const (
HttpsProxyType
Socks5ProxyType
)

const (
WebTokenType = "web"
ApiTokenType = "api"
SystemToken = "system"
)
1 change: 1 addition & 0 deletions globals/variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var CacheAcceptedModels []string
var CacheAcceptedExpire int64
var CacheAcceptedSize int64
var AcceptImageStore bool
var AcceptPromptStore bool
var CloseRegistration bool
var CloseRelay bool

Expand Down
106 changes: 99 additions & 7 deletions utils/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package utils

import (
"chat/globals"
"fmt"
"strings"
"time"
)

type Charge interface {
Expand All @@ -28,7 +30,11 @@ type Buffer struct {
ToolCalls *globals.ToolCalls `json:"tool_calls"`
ToolCallsCursor int `json:"tool_calls_cursor"`
FunctionCall *globals.FunctionCall `json:"function_call"`
StartTime *time.Time `json:"-"`
Prompts string `json:"prompts"`
TokenName string `json:"-"`
Charge Charge `json:"-"`
VisionRecall bool `json:"-"`
}

func initInputToken(model string, history []globals.Message) int {
Expand Down Expand Up @@ -71,6 +77,7 @@ func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
FunctionCall: nil,
ToolCalls: nil,
ToolCallsCursor: 0,
StartTime: ToPtr(time.Now()),
}
}

Expand All @@ -79,6 +86,11 @@ func (b *Buffer) GetCursor() int {
}

func (b *Buffer) GetQuota() float32 {
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(true))
}

func (b *Buffer) GetRecordQuota() float32 {
// end of the buffer, the output token is counted using the times
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false))
}

Expand Down Expand Up @@ -106,15 +118,23 @@ func (b *Buffer) GetChunk() string {
return b.Latest
}

func (b *Buffer) InitVisionRecall() {
// set the vision recall flag to true to prevent the buffer from adding more images of retrying the channel
b.VisionRecall = true
}

func (b *Buffer) AddImage(image *Image) {
if image != nil {
b.Images = append(b.Images, *image)
if image == nil || b.VisionRecall {
return
}

b.Images = append(b.Images, *image)

tokens := image.CountTokens(b.Model)
b.InputTokens += tokens

if b.Charge.IsBillingType(globals.TokenBilling) {
if image != nil {
b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput()
}
b.Quota += float32(tokens) / 1000 * b.Charge.GetInput()
}
}

Expand Down Expand Up @@ -145,6 +165,13 @@ func hitTool(tool globals.ToolCall, tools globals.ToolCalls) (int, *globals.Tool
return 0, nil
}

func appendTool(tool globals.ToolCall, chunk globals.ToolCall) string {
from := ToString(tool.Function.Arguments)
to := ToString(chunk.Function.Arguments)

return from + to
}

func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls {
if source == nil {
return target
Expand All @@ -157,7 +184,7 @@ func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.Too
idx, hit := hitTool(tool, tools)

if hit != nil {
tools[idx].Function.Arguments += tool.Function.Arguments
tools[idx].Function.Arguments = appendTool(tools[idx], tool)
} else {
tools = append(tools, tool)
}
Expand Down Expand Up @@ -209,6 +236,27 @@ func (b *Buffer) GetCharge() Charge {
return b.Charge
}

func (b *Buffer) ToChargeInfo() string {
switch b.Charge.GetType() {
case globals.TokenBilling:
return fmt.Sprintf(
"input tokens: %0.4f quota / 1k tokens\n"+
"output tokens: %0.4f quota / 1k tokens\n",
b.Charge.GetInput(), b.Charge.GetOutput(),
)
case globals.TimesBilling:
return fmt.Sprintf("%f quota per request\n", b.Charge.GetLimit())
case globals.NonBilling:
return "no cost"
}

return ""
}

func (b *Buffer) SetPrompts(prompts interface{}) {
b.Prompts = ToString(prompts)
}

func (b *Buffer) Read() string {
return b.Data
}
Expand Down Expand Up @@ -247,5 +295,49 @@ func (b *Buffer) CountOutputToken(running bool) int {
}

func (b *Buffer) CountToken() int {
return b.CountInputToken() + b.CountOutputToken(false)
return b.CountInputToken() + b.CountOutputToken(true)
}

func (b *Buffer) GetDuration() float32 {
if b.StartTime == nil {
return 0
}

return float32(time.Since(*b.StartTime).Seconds())
}

func (b *Buffer) GetStartTime() *time.Time {
return b.StartTime
}

func (b *Buffer) GetPrompts() string {
return b.Prompts
}

func (b *Buffer) GetTokenName() string {
if len(b.TokenName) == 0 {
return globals.WebTokenType
}

return b.TokenName
}

func (b *Buffer) SetTokenName(tokenName string) {
b.TokenName = tokenName
}

func (b *Buffer) GetRecordPrompts() string {
if !globals.AcceptPromptStore {
return ""
}

return b.GetPrompts()
}

func (b *Buffer) GetRecordResponsePrompts() string {
if !globals.AcceptPromptStore {
return ""
}

return b.Read()
}
22 changes: 22 additions & 0 deletions utils/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,28 @@ func Post(uri string, headers map[string]string, body interface{}, config ...glo
return data, err
}

func ToString(data interface{}) string {
switch v := data.(type) {
case string:
return v
case int, int8, int16, int32, int64:
return fmt.Sprintf("%d", v)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32, float64:
return fmt.Sprintf("%f", v)
case bool:
return fmt.Sprintf("%t", v)
default:
data := Marshal(data)
if len(data) > 0 {
return data
}

return fmt.Sprintf("%v", data)
}
}

func PostRaw(uri string, headers map[string]string, body interface{}, config ...globals.ProxyConfig) (data string, err error) {
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body), config)
if err != nil {
Expand Down

0 comments on commit 576213d

Please sign in to comment.