Merge pull request ollama#2296 from ollama/mxyng/img-tags
append image tags to user content
mxyng authored Feb 1, 2024
2 parents fe3cbd0 + f376140 commit bfbf2f7
Showing 6 changed files with 89 additions and 36 deletions.
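
In short: instead of carrying images in a separate CurrentImages field next to the prompt, the change inlines one " [img-N]" tag per image into the user's text and pairs each tag with an llm.ImageData{ID, Data} entry that travels with the request. A minimal standalone sketch of that transformation (hypothetical message and payloads, not code from the diff):

package main

import "fmt"

// ImageData mirrors the shape used by this PR (llm.ImageData):
// an integer tag ID plus the raw image bytes.
type ImageData struct {
	ID   int
	Data []byte
}

func main() {
	// Hypothetical user message with two attached images; the byte
	// slices stand in for real encoded image payloads.
	prompt := "What is in these pictures?"
	attachments := [][]byte{[]byte("png-1"), []byte("png-2")}

	var images []ImageData
	for i, data := range attachments {
		prompt += fmt.Sprintf(" [img-%d]", i)
		images = append(images, ImageData{ID: i, Data: data})
	}

	fmt.Println(prompt)      // What is in these pictures? [img-0] [img-1]
	fmt.Println(len(images)) // 2 image_data entries to send alongside it
}
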
9 changes: 3 additions & 6 deletions llm/dyn_ext_server.go
@@ -161,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
 func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
-	var imageData []ImageData
+
 	if len(predict.Images) > 0 {
-		for cnt, i := range predict.Images {
-			imageData = append(imageData, ImageData{Data: i, ID: cnt})
-		}
+		slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
 	}
-	slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
 
 	request := map[string]any{
 		"prompt":            predict.Prompt,
@@ -189,7 +186,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 		"penalize_nl":       predict.Options.PenalizeNewline,
 		"seed":              predict.Options.Seed,
 		"stop":              predict.Options.Stop,
-		"image_data":        imageData,
+		"image_data":        predict.Images,
 		"cache_prompt":      true,
 	}
 
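The " [img-N]" tags follow the llama.cpp server convention for multimodal models: each tag in the prompt refers to the image_data entry whose id matches N. Assuming llm.ImageData marshals its fields as id and data (Go's JSON encoder base64-encodes []byte), the serialized request would look roughly like this rough sketch, not captured from a real request:

{
  "prompt": "What is in this picture? [img-0]",
  "image_data": [{"id": 0, "data": "<base64 image bytes>"}],
  "cache_prompt": true
}
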
2 changes: 1 addition & 1 deletion llm/llama.go
@@ -62,7 +62,7 @@ const maxRetries = 3
 type PredictOpts struct {
 	Prompt  string
 	Format  string
-	Images  []api.ImageData
+	Images  []ImageData
 	Options api.Options
 }
 
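This type switch is the point of the PR's plumbing: api.ImageData is a bare byte slice, so it cannot carry the ID that a [img-N] tag refers to, while the llm package's ImageData pairs the bytes with that ID. A sketch of the two shapes as used in this diff (the JSON tags are an assumption):

// in the api package: raw image bytes only
type ImageData []byte

// in the llm package: bytes plus the ID referenced by a [img-N] tag
type ImageData struct {
	ID   int    `json:"id"`
	Data []byte `json:"data"`
}
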
25 changes: 17 additions & 8 deletions server/images.go
@@ -63,6 +63,7 @@ type PromptVars struct {
 	Prompt   string
 	Response string
 	First    bool
+	Images   []llm.ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -147,22 +148,21 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 }
 
 type ChatHistory struct {
-	Prompts       []PromptVars
-	CurrentImages []api.ImageData
-	LastSystem    string
+	Prompts    []PromptVars
+	LastSystem string
 }
 
 // ChatPrompts returns a list of formatted chat prompts from a list of messages
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
-	var currentImages []api.ImageData
 	lastSystem := m.System
 	currentVars := PromptVars{
 		First:  true,
 		System: m.System,
 	}
 
 	prompts := []PromptVars{}
+	var images []llm.ImageData
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
@@ -179,8 +179,18 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 				prompts = append(prompts, currentVars)
 				currentVars = PromptVars{}
 			}
+
 			currentVars.Prompt = msg.Content
-			currentImages = msg.Images
+			for i := range msg.Images {
+				id := len(images) + i
+				currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
+				currentVars.Images = append(currentVars.Images, llm.ImageData{
+					ID:   id,
+					Data: msg.Images[i],
+				})
+			}
+
+			images = append(images, currentVars.Images...)
 		case "assistant":
 			currentVars.Response = msg.Content
 			prompts = append(prompts, currentVars)
@@ -196,9 +206,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	return &ChatHistory{
-		Prompts:       prompts,
-		CurrentImages: currentImages,
-		LastSystem:    lastSystem,
+		Prompts:    prompts,
+		LastSystem: lastSystem,
 	}, nil
 }
 
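One subtlety worth noting: id := len(images) + i numbers images across the entire conversation, not per message, because images accumulates every image from earlier user turns. A small self-contained sketch of that numbering (hypothetical two-turn history):

package main

import "fmt"

type ImageData struct {
	ID   int
	Data []byte
}

func main() {
	// Turn 1 attaches two images, turn 2 attaches one.
	turns := [][][]byte{
		{[]byte("a"), []byte("b")},
		{[]byte("c")},
	}

	var images []ImageData
	for _, msgImages := range turns {
		prompt := "..."
		var cur []ImageData
		for i, data := range msgImages {
			id := len(images) + i // global, conversation-wide ID
			prompt += fmt.Sprintf(" [img-%d]", id)
			cur = append(cur, ImageData{ID: id, Data: data})
		}
		images = append(images, cur...)
		fmt.Println(prompt)
	}
	// Output:
	// ... [img-0] [img-1]
	// ... [img-2]
}
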
33 changes: 26 additions & 7 deletions server/images_test.go
@@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool {
 	if len(a.Prompts) != len(b.Prompts) {
 		return false
 	}
-	if len(a.CurrentImages) != len(b.CurrentImages) {
-		return false
-	}
+
 	for i, v := range a.Prompts {
-		if v != b.Prompts[i] {
+		if v.First != b.Prompts[i].First {
+			return false
+		}
+
+		if v.Response != b.Prompts[i].Response {
 			return false
 		}
-	}
-	for i, v := range a.CurrentImages {
-		if !bytes.Equal(v, b.CurrentImages[i]) {
+
+		if v.Prompt != b.Prompts[i].Prompt {
+			return false
+		}
+
+		if v.System != b.Prompts[i].System {
 			return false
 		}
+
+		if len(v.Images) != len(b.Prompts[i].Images) {
+			return false
+		}
+
+		for j, img := range v.Images {
+			if img.ID != b.Prompts[i].Images[j].ID {
+				return false
+			}
+
+			if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
+				return false
+			}
+		}
 	}
 	return a.LastSystem == b.LastSystem
 }
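The reason the old one-line checks had to go: PromptVars now contains a slice field (Images), and Go only defines == and != for comparable types, so v != b.Prompts[i] no longer compiles. The helper therefore compares field by field, with bytes.Equal for the image payloads. A stripped-down illustration (simplified stand-in types, not the real ones):

package main

import (
	"bytes"
	"fmt"
)

type vars struct {
	Prompt string
	Images [][]byte // slice field: vars is no longer comparable with ==
}

func equal(a, b vars) bool {
	// a != b would be a compile error ("invalid operation"), so
	// compare each field, using bytes.Equal for slice contents.
	if a.Prompt != b.Prompt || len(a.Images) != len(b.Images) {
		return false
	}
	for i := range a.Images {
		if !bytes.Equal(a.Images[i], b.Images[i]) {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(equal(vars{"p", [][]byte{{1}}}, vars{"p", [][]byte{{1}}})) // true
}
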
53 changes: 40 additions & 13 deletions server/routes.go
@@ -244,6 +244,10 @@ func GenerateHandler(c *gin.Context) {
 			promptVars.System = model.System
 		}
 
+		for i := range req.Images {
+			promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
+		}
+
 		p, err := model.PreResponsePrompt(promptVars)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -308,11 +312,19 @@ func GenerateHandler(c *gin.Context) {
 			ch <- resp
 		}
 
+		var images []llm.ImageData
+		for i := range req.Images {
+			images = append(images, llm.ImageData{
+				ID:   i,
+				Data: req.Images[i],
+			})
+		}
+
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  req.Images,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1139,7 +1151,8 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
-	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+
+	prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1182,7 +1195,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  chat.CurrentImages,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1229,42 +1242,55 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	if len(chat.Prompts) == 0 {
-		return "", nil
+		return "", nil, nil
 	}
 
 	var promptsToAdd []promptInfo
 	var totalTokenLength int
 	var systemPromptIncluded bool
 
+	var images []llm.ImageData
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
-		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+		prompt := chat.Prompts[i]
+		promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
 			break // reached max context length, stop adding more prompts
 		}
 
+		for j := range prompt.Images {
+			if totalTokenLength+768 > loaded.NumCtx {
+				// this decreases the token length but overestimating is fine
+				prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
+				continue
+			}
+
+			totalTokenLength += 768
+			images = append(images, prompt.Images[j])
+		}
+
 		totalTokenLength += len(encodedTokens)
-		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
-		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+		systemPromptIncluded = systemPromptIncluded || prompt.System != ""
+		promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
 	}
 
 	// ensure the system prompt is included, if not already
 	if chat.LastSystem != "" && !systemPromptIncluded {
 		var err error
 		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 	}
 
@@ -1275,11 +1301,12 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	for i, prompt := range promptsToAdd {
 		promptText, err := promptString(model, prompt.vars, i == 0)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 		result = promptText + result
 	}
-	return result, nil
+
+	return result, images, nil
 }
 
 // promptString applies the model template to the prompt
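The magic number 768 is a per-image token estimate used only for budgeting (the in-diff comment notes that overestimating is fine): an image that would overflow the context window has its [img-N] tag stripped from that prompt and is dropped, instead of evicting text. A self-contained sketch of the same rule (simplified; the real code interleaves this with text-token counting):

package main

import (
	"fmt"
	"strings"
)

type ImageData struct {
	ID   int
	Data []byte
}

// budgetImages charges ~768 context tokens per image; images that do
// not fit have their " [img-N]" tag removed and are not sent at all.
func budgetImages(prompt string, imgs []ImageData, used, numCtx int) (string, []ImageData, int) {
	var kept []ImageData
	for _, img := range imgs {
		if used+768 > numCtx {
			prompt = strings.ReplaceAll(prompt, fmt.Sprintf(" [img-%d]", img.ID), "")
			continue
		}
		used += 768
		kept = append(kept, img)
	}
	return prompt, kept, used
}

func main() {
	p, kept, used := budgetImages("describe [img-0] and [img-1]", []ImageData{{ID: 0}, {ID: 1}}, 900, 2048)
	fmt.Println(p, len(kept), used) // "describe [img-0] and" 1 1668
}
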
3 changes: 2 additions & 1 deletion server/routes_test.go
@@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) {
 				NumCtx: tt.numCtx,
 			},
 		}
-		got, err := trimmedPrompt(context.Background(), tt.chat, m)
+		// TODO: add tests for trimming images
+		got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
 		if tt.wantErr != "" {
 			if err == nil {
 				t.Errorf("ChatPrompt() expected error, got nil")
