Skip to content

Commit

Permalink
feat: merge pr coaidev#211 from @Sh1n3zZ: define sending defaults bas…
Browse files Browse the repository at this point in the history
…ed on different device types (coaidev#204); optimize tiktoken performance (coaidev#191) and function calling fields
  • Loading branch information
zmh-program authored Jun 22, 2024
2 parents 0ddd2d2 + 78ff6a1 commit a51dc7f
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 34 deletions.
11 changes: 7 additions & 4 deletions adapter/skylark/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package skylark
import (
"chat/globals"
"chat/utils"

structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/volcengine/volc-sdk-golang/service/maas/models/api"
)
Expand All @@ -20,19 +21,21 @@ func getFunctionCall(calls *globals.ToolCalls) *api.FunctionCall {
}

func getType(p globals.ToolProperty) string {
if p.Type == nil {
t, ok := p["type"]
if !ok {
return "string"
}

return *p.Type
return t.(string)
}

func getDescription(p globals.ToolProperty) string {
if p.Description == nil {
desc, ok := p["description"]
if !ok {
return ""
}

return *p.Description
return desc.(string)
}

func getValue(p globals.ToolProperty) *structpb.Value {
Expand Down
5 changes: 3 additions & 2 deletions app/src/store/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ import {
setNumberMemory,
} from "@/utils/memory.ts";
import { RootState } from "@/store/index.ts";
import { isMobile } from "@/utils/device";

export const sendKeys = ["Ctrl + Enter", "Enter"];
export const sendKeys = isMobile() ? ["Ctrl + Enter", "Enter"] : ["Enter", "Ctrl + Enter"];
export const initialSettings = {
context: true,
align: false,
history: 8,
sender: false,
sender: isMobile(), // Defaults to true (Enter) in the case of mobile and false (Ctrl + Enter) on PCs
max_tokens: 2000,
temperature: 0.6,
top_p: 1,
Expand Down
5 changes: 3 additions & 2 deletions globals/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package globals

import (
"fmt"
"strings"

"github.com/natefinch/lumberjack"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
"strings"
)

const DefaultLoggerFile = "chatnio.log"
Expand All @@ -25,7 +26,7 @@ func (l *AppLogger) Format(entry *logrus.Entry) ([]byte, error) {
)

if !viper.GetBool("log.ignore_console") {
fmt.Println(data)
fmt.Print(data)
}

return []byte(data), nil
Expand Down
5 changes: 3 additions & 2 deletions globals/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type ToolFunction struct {
type ToolParameters struct {
Type string `json:"type"`
Properties ToolProperties `json:"properties"`
Required []string `json:"required"`
Required *[]string `json:"required,omitempty"`
}

type ToolProperties map[string]ToolProperty
Expand All @@ -25,7 +25,8 @@ type ToolProperties map[string]ToolProperty

type JsonSchemaType any
type JSONSchemaDefinition any
type ToolProperty struct {
type ToolProperty map[string]interface{}
type DetailToolProperty struct {
Type *string `json:"type,omitempty"`
Enum *[]JsonSchemaType `json:"enum,omitempty"`
Const *JsonSchemaType `json:"const,omitempty"`
Expand Down
4 changes: 2 additions & 2 deletions manager/chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
},
Usage: Usage{
PromptTokens: buffer.CountInputToken(),
CompletionTokens: buffer.CountOutputToken(),
CompletionTokens: buffer.CountOutputToken(false),
TotalTokens: buffer.CountToken(),
},
Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())),
Expand Down Expand Up @@ -205,7 +205,7 @@ func getStreamTranshipmentForm(id string, created int64, form RelayForm, data *g
},
Usage: Usage{
PromptTokens: buffer.CountInputToken(),
CompletionTokens: buffer.CountOutputToken(),
CompletionTokens: buffer.CountOutputToken(true),
TotalTokens: buffer.CountToken(),
},
Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())),
Expand Down
21 changes: 11 additions & 10 deletions utils/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func initInputToken(model string, history []globals.Message) int {
})
}

return CountTokenPrice(history, model)
return NumTokensFromMessages(history, model, false)
}

func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
Expand All @@ -79,7 +79,7 @@ func (b *Buffer) GetCursor() int {
}

func (b *Buffer) GetQuota() float32 {
return b.Quota + CountOutputToken(b.Charge, b.Model, b.ReadTimes())
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false))
}

func (b *Buffer) Write(data string) string {
Expand Down Expand Up @@ -197,11 +197,6 @@ func (b *Buffer) IsFunctionCalling() bool {
return b.FunctionCall != nil || b.ToolCalls != nil
}

func (b *Buffer) WriteBytes(data []byte) []byte {
b.Write(string(data))
return data
}

func (b *Buffer) IsEmpty() bool {
return b.Cursor == 0 && !b.IsFunctionCalling()
}
Expand Down Expand Up @@ -241,10 +236,16 @@ func (b *Buffer) CountInputToken() int {
return b.InputTokens
}

func (b *Buffer) CountOutputToken() int {
return b.ReadTimes() * GetWeightByModel(b.Model)
func (b *Buffer) CountOutputToken(running bool) int {
if running {
// performance optimization:
// if the buffer is still running, the output token counted using the times instead
return b.Times
}

return NumTokensFromResponse(b.Read(), b.Model)
}

func (b *Buffer) CountToken() int {
return b.CountInputToken() + b.CountOutputToken()
return b.CountInputToken() + b.CountOutputToken(false)
}
34 changes: 22 additions & 12 deletions utils/tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package utils
import (
"chat/globals"
"fmt"
"github.com/pkoukk/tiktoken-go"
"strings"

"github.com/pkoukk/tiktoken-go"
)

// Using https://github.com/pkoukk/tiktoken-go
Expand Down Expand Up @@ -45,9 +46,10 @@ func GetWeightByModel(model string) int {
}
}
}
func NumTokensFromMessages(messages []globals.Message, model string) (tokens int) {
func NumTokensFromMessages(messages []globals.Message, model string, responseType bool) (tokens int) {
tokensPerMessage := GetWeightByModel(model)
tkm, err := tiktoken.EncodingForModel(model)

if err != nil {
// the method above was deprecated, use the recall method instead
// can not encode messages, use length of messages as a proxy for number of tokens
Expand All @@ -59,25 +61,33 @@ func NumTokensFromMessages(messages []globals.Message, model string) (tokens int
if globals.DebugMode {
globals.Debug(fmt.Sprintf("[tiktoken] error encoding messages: %s (model: %s), using default model instead", err, model))
}
return NumTokensFromMessages(messages, globals.GPT3Turbo0613)
return NumTokensFromMessages(messages, globals.GPT3Turbo0613, responseType)
}

for _, message := range messages {
tokens +=
len(tkm.Encode(message.Content, nil, nil)) +
len(tkm.Encode(message.Role, nil, nil)) +
tokensPerMessage
tokens += len(tkm.Encode(message.Content, nil, nil))

if !responseType {
tokens += len(tkm.Encode(message.Role, nil, nil)) + tokensPerMessage
}
}

if !responseType {
tokens += 3 // every reply is primed with <|start|>assistant<|message|>
}
tokens += 3 // every reply is primed with <|start|>assistant<|message|>

if globals.DebugMode {
globals.Debug(fmt.Sprintf("[tiktoken] num tokens from messages: %d (tokens per message: %d, model: %s)", tokens, tokensPerMessage, model))
}
return tokens
}

func CountTokenPrice(messages []globals.Message, model string) int {
return NumTokensFromMessages(messages, model) * GetWeightByModel(model)
func NumTokensFromResponse(response string, model string) int {
if len(response) == 0 {
return 0
}

return NumTokensFromMessages([]globals.Message{{Content: response}}, model, true)
}

func CountInputQuota(charge Charge, token int) float32 {
Expand All @@ -88,10 +98,10 @@ func CountInputQuota(charge Charge, token int) float32 {
return 0
}

func CountOutputToken(charge Charge, model string, token int) float32 {
func CountOutputToken(charge Charge, token int) float32 {
switch charge.GetType() {
case globals.TokenBilling:
return float32(token*GetWeightByModel(model)) / 1000 * charge.GetOutput()
return float32(token) / 1000 * charge.GetOutput()
case globals.TimesBilling:
return charge.GetOutput()
default:
Expand Down

0 comments on commit a51dc7f

Please sign in to comment.