Skip to content

Commit

Permalink
add keep_alive to generate/chat/embedding api endpoints (ollama#2146)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdevine authored Jan 26, 2024
1 parent cc4915e commit b5cf31b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 20 deletions.
42 changes: 25 additions & 17 deletions api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,26 @@ func (e StatusError) Error() string {
type ImageData []byte

type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Images []ImageData `json:"images,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Images []ImageData `json:"images,omitempty"`

Options map[string]interface{} `json:"options"`
}

type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`

Options map[string]interface{} `json:"options"`
}
Expand Down Expand Up @@ -126,8 +128,9 @@ type Runner struct {
}

type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Model string `json:"model"`
Prompt string `json:"prompt"`
KeepAlive *Duration `json:"keep_alive,omitempty"`

Options map[string]interface{} `json:"options"`
}
Expand Down Expand Up @@ -413,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
case float64:
if t < 0 {
t = math.MaxFloat64
d.Duration = time.Duration(t)
} else {
d.Duration = time.Duration(t * float64(time.Second))
}

d.Duration = time.Duration(t)
case string:
d.Duration, err = time.ParseDuration(t)
if err != nil {
return err
}
if d.Duration < 0 {
mf := math.MaxFloat64
d.Duration = time.Duration(mf)
}
}

return nil
Expand Down
26 changes: 23 additions & 3 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
return
}

sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}

if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
Expand Down Expand Up @@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sessionDuration := defaultSessionDuration

var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}

if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
Expand Down Expand Up @@ -1074,7 +1087,14 @@ func ChatHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sessionDuration := defaultSessionDuration

var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}

if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
Expand Down

0 comments on commit b5cf31b

Please sign in to comment.