-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathollama.go
93 lines (77 loc) · 1.64 KB
/
ollama.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package main
import (
"context"
"fmt"
"github.com/ollama/ollama/api"
)
// OllamaConfig holds the user-tunable settings for the Ollama backend.
type OllamaConfig struct {
	// Model is the Ollama model to use. An empty value means NewOllamaClient
	// falls back to ollamaDefaultModel. The env tag presumably lets an
	// env-parsing library fill it from HELPME_MODEL; the loader is not
	// visible here, so confirm at the call site.
	Model string `env:"HELPME_MODEL"`
}
const (
	// ollamaDefaultModel is the model NewOllamaClient selects when the
	// config's Model is still empty after all options have been applied.
	ollamaDefaultModel = "codegemma:instruct"
)
// Ollama is a text-generation backend that talks to an Ollama server.
// Construct it with NewOllamaClient.
type Ollama struct {
	// config holds the resolved settings; when built by NewOllamaClient its
	// Model field is non-empty.
	config *OllamaConfig
	// client is the Ollama API client used for all requests.
	client *api.Client
}
// NewOllamaClient builds an Ollama backend from the server settings in the
// environment (api.ClientFromEnvironment). It applies opts to config, checks
// that the server answers a heartbeat, and pulls the configured model if the
// server does not already list it.
//
// config is modified in place: opts run first, then an empty Model is
// replaced with ollamaDefaultModel. A nil config is treated as an empty one.
//
// ctx bounds every request made here, including the model download. Errors
// are wrapped with context and never cause a panic.
func NewOllamaClient(ctx context.Context, config *OllamaConfig, opts ...OllamaOpt) (*Ollama, error) {
	if config == nil {
		config = &OllamaConfig{}
	}
	for _, opt := range opts {
		opt(config)
	}
	if config.Model == "" {
		config.Model = ollamaDefaultModel
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		return nil, fmt.Errorf("creating ollama client: %w", err)
	}
	if err := client.Heartbeat(ctx); err != nil {
		return nil, fmt.Errorf("contacting ollama server: %w", err)
	}
	if err := ensureOllamaModel(ctx, client, config.Model); err != nil {
		return nil, err
	}

	return &Ollama{
		config: config,
		client: client,
	}, nil
}

// ensureOllamaModel pulls model through client unless the server already lists it.
func ensureOllamaModel(ctx context.Context, client *api.Client, model string) error {
	list, err := client.List(ctx)
	if err != nil {
		return fmt.Errorf("listing ollama models: %w", err)
	}
	// NOTE(review): exact string match. A model given without a tag may be
	// listed by the server with an explicit tag (e.g. ":latest") and then be
	// re-pulled on every start. Confirm whether normalization is wanted.
	for _, m := range list.Models {
		if m.Model == model {
			return nil
		}
	}

	fmt.Printf("Downloading %s...\n", model)
	err = client.Pull(ctx, &api.PullRequest{
		Model: model,
		// NOTE(review): Insecure allows insecure connections to the registry.
		// Kept for compatibility with the original behaviour; consider
		// disabling it unless a private development registry requires it.
		Insecure: true,
	}, func(api.ProgressResponse) error {
		// Progress updates are deliberately ignored.
		return nil
	})
	if err != nil {
		return fmt.Errorf("pulling model %q: %w", model, err)
	}
	fmt.Println("Done")
	return nil
}
// OllamaOpt is a functional option that adjusts an OllamaConfig.
// NewOllamaClient applies options in order before filling in defaults.
type OllamaOpt func(*OllamaConfig)
// WithModel returns an option that sets the config's Model to model,
// replacing whatever value the config held before the option ran.
func WithModel(model string) OllamaOpt {
	setModel := func(cfg *OllamaConfig) {
		cfg.Model = model
	}
	return setModel
}
// Generate starts a streaming completion of prompt, using system as the
// system prompt and the configured model. It returns immediately and does
// the work in a background goroutine.
//
// Each chunk of generated text is sent on ch as it arrives; ch is never
// closed. When the request ends, its result (nil on success) is sent exactly
// once on errCh. The caller must receive from errCh, or the goroutine blocks
// on that send.
//
// If ctx is done while a chunk is waiting to be delivered on ch, the stream
// is aborted and ctx.Err() is reported on errCh. This way a consumer that
// stops reading ch cannot pin the goroutine and the HTTP stream forever.
//
// The returned error is always nil; failures are reported on errCh.
func (o *Ollama) Generate(ctx context.Context, system, prompt string, ch chan string, errCh chan error) error {
	req := &api.GenerateRequest{
		Model:  o.config.Model,
		Prompt: prompt,
		System: system,
	}
	go func() {
		errCh <- o.client.Generate(ctx, req, func(response api.GenerateResponse) error {
			select {
			case ch <- response.Response:
				return nil
			case <-ctx.Done():
				// Returning an error from the callback aborts the stream.
				return ctx.Err()
			}
		})
	}()
	return nil
}