Skip to content

Commit

Permalink
rerefactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng authored and jmorganca committed Feb 15, 2024
1 parent 823a520 commit e43648a
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 251 deletions.
38 changes: 30 additions & 8 deletions app/lifecycle/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package lifecycle

import (
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"mime"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
Expand All @@ -21,7 +23,7 @@ import (
)

var (
UpdateCheckURLBase = "https://ollama.ai/api/update"
UpdateCheckURLBase = "https://ollama.com/api/update"
UpdateDownloaded = false
)

Expand All @@ -47,22 +49,42 @@ func getClient(req *http.Request) http.Client {

func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
var updateResp UpdateResponse
updateCheckURL := UpdateCheckURLBase + "?os=" + runtime.GOOS + "&arch=" + runtime.GOARCH + "&version=" + version.Version
headers := make(http.Header)
err := auth.SignRequest(http.MethodGet, updateCheckURL, nil, headers)

requestURL, err := url.Parse(UpdateCheckURLBase)
if err != nil {
slog.Info(fmt.Sprintf("failed to sign update request %s", err))
return false, updateResp
}

query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
query.Add("version", version.Version)
query.Add("ts", fmt.Sprintf("%d", time.Now().Unix()))

nonce, err := auth.NewNonce(rand.Reader, 16)
if err != nil {
return false, updateResp
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, updateCheckURL, nil)

query.Add("nonce", nonce)
requestURL.RawQuery = query.Encode()

data := []byte(fmt.Sprintf("%s,%s", http.MethodGet, requestURL.RequestURI()))
signature, err := auth.Sign(ctx, data)
if err != nil {
return false, updateResp
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
return false, updateResp
}
req.Header = headers
req.Header.Set("Authorization", signature)
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
client := getClient(req)

slog.Debug(fmt.Sprintf("checking for available update at %s with headers %v", updateCheckURL, headers))
slog.Debug(fmt.Sprintf("checking for available update at %s with headers %v", requestURL, req.Header))
resp, err := client.Do(req)
if err != nil {
slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
Expand Down
2 changes: 1 addition & 1 deletion app/ollama.iss
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#define MyAppVersion "0.0.0"
#endif
#define MyAppPublisher "Ollama, Inc."
#define MyAppURL "https://ollama.ai/"
#define MyAppURL "https://ollama.com/"
#define MyAppExeName "ollama app.exe"
#define MyIcon ".\assets\app.ico"

Expand Down
153 changes: 13 additions & 140 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,185 +4,58 @@ import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"

"golang.org/x/crypto/ssh"

"github.com/jmorganca/ollama/api"
)

const (
KeyType = "id_ed25519"
)

type AuthRedirect struct {
Realm string
Service string
Scope string
}
const defaultPrivateKey = "id_ed25519"

type SignatureData struct {
Method string
Path string
Data []byte
}

func generateNonce(length int) (string, error) {
func NewNonce(r io.Reader, length int) (string, error) {
nonce := make([]byte, length)
_, err := rand.Read(nonce)
if err != nil {
if _, err := io.ReadFull(r, nonce); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(nonce), nil
}

func (r AuthRedirect) URL() (*url.URL, error) {
redirectURL, err := url.Parse(r.Realm)
if err != nil {
return nil, err
}

values := redirectURL.Query()

values.Add("service", r.Service)

for _, s := range strings.Split(r.Scope, " ") {
values.Add("scope", s)
}

values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))

nonce, err := generateNonce(16)
if err != nil {
return nil, err
}
values.Add("nonce", nonce)

redirectURL.RawQuery = values.Encode()
return redirectURL, nil
return base64.RawURLEncoding.EncodeToString(nonce), nil
}

func SignRequest(method, url string, data []byte, headers http.Header) error {
func Sign(ctx context.Context, bts []byte) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return err
}

keyPath := filepath.Join(home, ".ollama", KeyType)

rawKey, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return err
}

s := SignatureData{
Method: method,
Path: url,
Data: data,
}

sig, err := s.Sign(rawKey)
if err != nil {
return err
}

headers.Set("Authorization", sig)
return nil
}

func GetAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
redirectURL, err := redirData.URL()
if err != nil {
return "", err
}

headers := make(http.Header)
err = SignRequest(http.MethodGet, redirectURL.String(), nil, headers)
if err != nil {
return "", err
}
resp, err := MakeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get token: %q", err))
return "", err
}
defer resp.Body.Close()

if resp.StatusCode >= http.StatusBadRequest {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("%d: %v", resp.StatusCode, err)
} else if len(responseBody) > 0 {
return "", fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
}

return "", fmt.Errorf("%s", resp.Status)
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)

respBody, err := io.ReadAll(resp.Body)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}

var tok api.TokenResponse
if err := json.Unmarshal(respBody, &tok); err != nil {
return "", err
}

return tok.Token, nil
}

// Bytes returns a byte slice of the data to sign for the request
func (s SignatureData) Bytes() []byte {
// We first derive the content hash of the request body using:
// base64(hex(sha256(request body)))

hash := sha256.Sum256(s.Data)
hashHex := make([]byte, hex.EncodedLen(len(hash)))
hex.Encode(hashHex, hash[:])
contentHash := base64.StdEncoding.EncodeToString(hashHex)

// We then put the entire request together in a serialize string using:
// "<method>,<uri>,<content hash>"
// e.g. "GET,http://localhost,OTdkZjM1O..."

return []byte(strings.Join([]string{s.Method, s.Path, contentHash}, ","))
}

// SignData takes a SignatureData object and signs it with a raw private key
func (s SignatureData) Sign(rawKey []byte) (string, error) {
signer, err := ssh.ParsePrivateKey(rawKey)
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
if err != nil {
return "", err
}

// get the pubkey, but remove the type
pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey())
parts := bytes.Split(pubKey, []byte(" "))
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
parts := bytes.Split(publicKey, []byte(" "))
if len(parts) < 2 {
return "", fmt.Errorf("malformed public key")
}

signedData, err := signer.Sign(nil, s.Bytes())
signedData, err := privateKey.Sign(rand.Reader, bts)
if err != nil {
return "", err
}

// signature is <pubkey>:<signature>
sig := fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob))
return sig, nil
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
}
72 changes: 0 additions & 72 deletions auth/request.go

This file was deleted.

Loading

0 comments on commit e43648a

Please sign in to comment.