Skip to content

Commit

Permalink
Add proxy config for multi-sync (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
pulltheflower authored Jan 10, 2025
1 parent 5395960 commit 10a4559
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 94 deletions.
108 changes: 19 additions & 89 deletions builder/multisync/client.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package multisync

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"

"opencsg.com/csghub-server/builder/rpc"
"opencsg.com/csghub-server/common/types"
)

Expand All @@ -20,132 +18,64 @@ type Client interface {
}

// FromOpenCSG returns a Client that talks to the given endpoint and
// authenticates with accessToken.
//
// NOTE(review): this listing is a rendered diff without +/- markers, so two
// alternative return statements appear back to back. Read literally, only the
// first one can execute. Presumably the struct literal using
// http.DefaultClient is the removed version and the rpc.NewHttpClient form is
// its replacement — confirm against the repository.
func FromOpenCSG(endpoint string, accessToken string) Client {
// Variant A: keep the endpoint and token and issue requests with http.DefaultClient.
return &commonClient{
endpoint: endpoint,
hc: http.DefaultClient,
authToken: accessToken,
}
// Variant B: hand endpoint and API-key auth to the shared rpc HTTP client.
return &commonClient{rpcClent: rpc.NewHttpClient(endpoint, rpc.AuthWithApiKey(accessToken))}
}

// commonClient is the Client implementation returned by FromOpenCSG.
//
// NOTE(review): the fields below come from both sides of a diff rendered
// without +/- markers. endpoint, hc and authToken are used by the hand-written
// net/http request code; rpcClent is used by the rpc-based request code.
// Presumably only rpcClent survives the change — confirm.
type commonClient struct {
// endpoint is the base URL that request paths are appended to.
endpoint string
// hc is the HTTP client used to send requests.
hc *http.Client
// authToken is sent as "Authorization: Bearer <token>".
authToken string
// rpcClent is the shared rpc HTTP client.
// NOTE(review): the name looks like a typo for "rpcClient"; every method
// references it, so a rename has to be done across the whole file.
rpcClent *rpc.HttpClient
}

// Latest requests /api/v1/sync/version/latest from the remote endpoint,
// passing currentVersion as the "cur" query parameter, and decodes the reply
// into a types.SyncVersionResponse.
//
// NOTE(review): this body interleaves two alternative implementations from a
// diff rendered without +/- markers (a hand-written net/http call and an
// rpcClent.Get call). They are alternatives, not sequential steps; as written
// the second `err :=` declares no new variable. Presumably the rpcClent lines
// are the post-change version — confirm against the repository.
func (c *commonClient) Latest(ctx context.Context, currentVersion int64) (types.SyncVersionResponse, error) {
// Variant A (net/http): build the absolute URL and attach the bearer token by hand.
url := fmt.Sprintf("%s/api/v1/sync/version/latest?cur=%d", c.endpoint, currentVersion)
// The request-construction error is discarded here.
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req.Header.Add("Authorization", "Bearer "+c.authToken)
resp, err := c.hc.Do(req)
if err != nil {
return types.SyncVersionResponse{}, fmt.Errorf("failed to get latest version from endpoint %s, param cur:%d, cause: %w",
c.endpoint, currentVersion, err)
}
defer resp.Body.Close()

// Any non-200 status is an error; the response body is included in the message.
if resp.StatusCode != http.StatusOK {
var data bytes.Buffer
_, _ = data.ReadFrom(resp.Body)
return types.SyncVersionResponse{}, fmt.Errorf("failed to get latest version from endpoint %s, param cur:%d, status code: %d, body: %s",
c.endpoint, currentVersion, resp.StatusCode, data.String())
}
var svc types.SyncVersionResponse
err = json.NewDecoder(resp.Body).Decode(&svc)
// Variant B (rpc client): only the path is built here; the endpoint is held by rpcClent.
path := fmt.Sprintf("/api/v1/sync/version/latest?cur=%d", currentVersion)

err := c.rpcClent.Get(ctx, path, &svc)
if err != nil {
// Two alternative error returns: the first yields a zero-value response,
// the second yields svc as it stands after the failed call.
return types.SyncVersionResponse{}, fmt.Errorf("failed to decode response body as types.SyncVersionResponse, cause: %w", err)
return svc, fmt.Errorf("failed to get latest version, cause: %w", err)
}
return svc, nil
}

// ModelInfo fetches /api/v1/models/<namespace>/<name> for the repository named
// by v.RepoPath and returns a pointer to the Data field of the decoded
// types.ModelResponse.
//
// NOTE(review): this body interleaves two alternative implementations from a
// diff rendered without +/- markers (hand-written net/http vs. rpcClent.Get);
// `url` and `err` are each declared twice as a result. Presumably the rpcClent
// lines are the post-change version — confirm against the repository.
func (c *commonClient) ModelInfo(ctx context.Context, v types.SyncVersion) (*types.Model, error) {
// Split RepoPath at the first "/". The "found" result is discarded, so a
// path without "/" yields an empty name.
namespace, name, _ := strings.Cut(v.RepoPath, "/")
// Variant A (net/http): absolute URL plus a manually attached bearer token.
url := fmt.Sprintf("%s/api/v1/models/%s/%s", c.endpoint, namespace, name)
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req.Header.Add("Authorization", "Bearer "+c.authToken)
resp, err := c.hc.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to get model info from endpoint %s, repo path:%s, cause: %w",
c.endpoint, v.RepoPath, err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to get model info from endpoint %s, repo path:%s, status code: %d",
c.endpoint, v.RepoPath, resp.StatusCode)
}
// Variant B (rpc client): path only; the endpoint is held by rpcClent.
url := fmt.Sprintf("/api/v1/models/%s/%s", namespace, name)
var res types.ModelResponse
err = json.NewDecoder(resp.Body).Decode(&res)
err := c.rpcClent.Get(ctx, url, &res)
if err != nil {
// Two alternative error messages, one per variant.
return nil, fmt.Errorf("failed to decode response body as types.Model, cause: %w", err)
return nil, fmt.Errorf("failed to get model info, cause: %w", err)
}
return &res.Data, nil
}

// DatasetInfo fetches /api/v1/datasets/<namespace>/<name> for the repository
// named by v.RepoPath and returns a pointer to the Data field of the decoded
// types.DatasetResponse.
//
// NOTE(review): this body interleaves two alternative implementations from a
// diff rendered without +/- markers (hand-written net/http vs. rpcClent.Get);
// `url` and `err` are each declared twice as a result. Presumably the rpcClent
// lines are the post-change version — confirm against the repository.
func (c *commonClient) DatasetInfo(ctx context.Context, v types.SyncVersion) (*types.Dataset, error) {
// Split RepoPath at the first "/". The "found" result is discarded, so a
// path without "/" yields an empty name.
namespace, name, _ := strings.Cut(v.RepoPath, "/")
// Variant A (net/http): absolute URL plus a manually attached bearer token.
url := fmt.Sprintf("%s/api/v1/datasets/%s/%s", c.endpoint, namespace, name)
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req.Header.Add("Authorization", "Bearer "+c.authToken)
resp, err := c.hc.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to get dataset info from endpoint %s, repo path:%s, cause: %w",
c.endpoint, v.RepoPath, err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to get dataset info from endpoint %s, repo path:%s, status code: %d",
c.endpoint, v.RepoPath, resp.StatusCode)
}
// Variant B (rpc client): path only; the endpoint is held by rpcClent.
url := fmt.Sprintf("/api/v1/datasets/%s/%s", namespace, name)
var res types.DatasetResponse
err = json.NewDecoder(resp.Body).Decode(&res)
err := c.rpcClent.Get(ctx, url, &res)
if err != nil {
// Two alternative error messages, one per variant.
return nil, fmt.Errorf("failed to decode response body as types.Dataset, cause: %w", err)
return nil, fmt.Errorf("failed to get dataset info, cause: %w", err)
}
return &res.Data, nil
}

// ReadMeData fetches /api/v1/<RepoType>s/<namespace>/<name>/raw/README.md for
// the repository described by v and returns the Data string of the decoded
// types.ReadMeResponse.
//
// NOTE(review): this body interleaves two alternative implementations from a
// diff rendered without +/- markers (hand-written net/http vs. rpcClent.Get);
// `url` and `err` are each declared twice as a result. Presumably the rpcClent
// lines are the post-change version — confirm against the repository.
func (c *commonClient) ReadMeData(ctx context.Context, v types.SyncVersion) (string, error) {
// Split RepoPath at the first "/". The "found" result is discarded, so a
// path without "/" yields an empty name.
namespace, name, _ := strings.Cut(v.RepoPath, "/")
// Variant A (net/http): the route segment is v.RepoType with an "s" appended.
url := fmt.Sprintf("%s/api/v1/%ss/%s/%s/raw/README.md", c.endpoint, v.RepoType, namespace, name)
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req.Header.Add("Authorization", "Bearer "+c.authToken)
resp, err := c.hc.Do(req)
if err != nil {
return "", fmt.Errorf("failed to get readme data endpoint %s, repo path:%s, cause: %w",
c.endpoint, v.RepoPath, err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return "", fmt.Errorf("failed to get readme data from endpoint %s, repo path:%s, status code: %d",
c.endpoint, v.RepoPath, resp.StatusCode)
}
// Variant B (rpc client): path only; the endpoint is held by rpcClent.
url := fmt.Sprintf("/api/v1/%ss/%s/%s/raw/README.md", v.RepoType, namespace, name)
var res types.ReadMeResponse
err = json.NewDecoder(resp.Body).Decode(&res)
err := c.rpcClent.Get(ctx, url, &res)
if err != nil {
// NOTE(review): both messages below mention "Dataset"/"dataset info" even
// though this method fetches README data — looks like a copy-paste from
// DatasetInfo; "failed to get readme data" is probably what was meant.
return "", fmt.Errorf("failed to decode response body as types.Dataset, cause: %w", err)
return "", fmt.Errorf("failed to get dataset info, cause: %w", err)
}
return res.Data, nil
}

// FileList fetches /api/v1/<RepoType>s/<namespace>/<name>/all_files for the
// repository described by v and returns the Data slice of the decoded
// types.AllFilesResponse.
//
// NOTE(review): this body interleaves two alternative implementations from a
// diff rendered without +/- markers (hand-written net/http vs. rpcClent.Get);
// `url` and `err` are each declared twice as a result. Presumably the rpcClent
// lines are the post-change version — confirm against the repository.
func (c *commonClient) FileList(ctx context.Context, v types.SyncVersion) ([]types.File, error) {
// Split RepoPath at the first "/". The "found" result is discarded, so a
// path without "/" yields an empty name.
namespace, name, _ := strings.Cut(v.RepoPath, "/")
// Variant A (net/http): the route segment is v.RepoType with an "s" appended.
url := fmt.Sprintf("%s/api/v1/%ss/%s/%s/all_files", c.endpoint, v.RepoType, namespace, name)
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req.Header.Add("Authorization", "Bearer "+c.authToken)
resp, err := c.hc.Do(req)
// NOTE(review): the two messages below say "readme data" although this
// method lists files — looks like a copy-paste from ReadMeData.
if err != nil {
return nil, fmt.Errorf("failed to get readme data endpoint %s, repo path:%s, cause: %w",
c.endpoint, v.RepoPath, err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return nil, fmt.Errorf("failed to get readme data from endpoint %s, repo path:%s, status code: %d",
c.endpoint, v.RepoPath, resp.StatusCode)
}
// Variant B (rpc client): path only; the endpoint is held by rpcClent.
url := fmt.Sprintf("/api/v1/%ss/%s/%s/all_files", v.RepoType, namespace, name)
var res types.AllFilesResponse
err = json.NewDecoder(resp.Body).Decode(&res)
err := c.rpcClent.Get(ctx, url, &res)
if err != nil {
// Two alternative error messages, one per variant; the first still names types.Dataset.
return nil, fmt.Errorf("failed to decode response body as types.Dataset, cause: %w", err)
return nil, fmt.Errorf("failed to get file list, cause: %w", err)
}
return res.Data, nil
}
34 changes: 33 additions & 1 deletion builder/rpc/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,46 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"

"opencsg.com/csghub-server/common/config"
)

// NewHttpClient builds an HttpClient for endpoint that applies opts as its
// auth options.
//
// The returned client is backed by http.DefaultClient unless the loaded config
// has the proxy enabled with a non-empty, parseable URL and one of
// cfg.Proxy.Hosts occurs in endpoint; in that case it is backed by a new
// http.Client whose Transport sends requests through that proxy URL.
//
// NOTE(review): this listing is a rendered diff without +/- markers. The bare
// `return &HttpClient{` line directly below is presumably the removed opening
// of the old body and `defaultClient := &HttpClient{` its replacement —
// confirm against the repository.
func NewHttpClient(endpoint string, opts ...RequestOption) *HttpClient {
return &HttpClient{
defaultClient := &HttpClient{
endpoint: endpoint,
hc: http.DefaultClient,
authOpts: opts,
}
// The config is loaded on every call. A load failure is not reported; the
// non-proxied client is returned instead.
cfg, err := config.LoadConfig()
if err != nil {
return defaultClient
}
if !cfg.Proxy.Enable || cfg.Proxy.URL == "" {
return defaultClient
}
proxyHosts := cfg.Proxy.Hosts
// An unparseable proxy URL is likewise swallowed and proxying is skipped.
proxyURL, err := url.Parse(cfg.Proxy.URL)
if err != nil {
return defaultClient
}

// Matching is a plain substring test against the whole endpoint string, so a
// host entry also matches when it appears inside a longer host name or in
// the path, and an empty entry matches every endpoint.
// NOTE(review): comparing against the parsed endpoint's hostname may be
// what was intended — confirm.
for _, host := range proxyHosts {
if strings.Contains(endpoint, host) {
return &HttpClient{
endpoint: endpoint,
hc: &http.Client{
// Only Proxy is set on this Transport; every other field is left at its zero value.
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
},
authOpts: opts,
}
}
}

return defaultClient
}

type HttpClient struct {
Expand Down
6 changes: 6 additions & 0 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ type Config struct {
SyncAsClientCronExpression string `env:"STARHUB_SERVER_CRON_JOB_SYNC_AS_CLIENT_CRON_EXPRESSION, default=0 * * * *"`
CalcRecomScoreCronExpression string `env:"STARHUB_SERVER_CRON_JOB_CLAC_RECOM_SCORE_CRON_EXPRESSION, default=0 1 * * *"`
}

// Proxy holds the outbound HTTP proxy settings.
Proxy struct {
// Enable turns proxying on; the tag default is false.
Enable bool `env:"STARHUB_SERVER_PROXY_ENABLE, default=false"`
// URL is the proxy address; the tag default is the empty string.
URL string `env:"STARHUB_SERVER_PROXY_URL, default="`
// Hosts lists the hosts whose traffic should use the proxy. The tag names
// ';' as the delimiter, so presumably the env value is a ';'-separated list
// — confirm against the env-loading library.
Hosts []string `env:"STARHUB_SERVER_PROXY_HOSTS, delimiter=;"`
}
}

func SetConfigFile(file string) {
Expand Down
3 changes: 3 additions & 0 deletions common/config/config.toml.example
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ endpoint = "localhost:7233"
[cron_job]
sync_as_client_cron_expression = "0 * * * *"
calc_recom_score_cron_expression = "0 1 * * *"

# Outbound proxy settings. The example value lists hosts separated by ';'.
# NOTE(review): only `hosts` is shown here; the Go config also defines an
# enable flag and a proxy URL — confirm which keys are needed for the proxy to
# take effect.
[proxy]
hosts = "opencsg.com;sync.opencsg.com"
20 changes: 16 additions & 4 deletions mirror/lfssyncer/minio.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type MinioLFSSyncWorker struct {
config *config.Config
repoStore database.RepoStore
numWorkers int
httpClient *http.Client
}

func NewMinioLFSSyncWorker(config *config.Config, numWorkers int) (*MinioLFSSyncWorker, error) {
Expand All @@ -53,6 +54,19 @@ func NewMinioLFSSyncWorker(config *config.Config, numWorkers int) (*MinioLFSSync
}
w.mq = mq
w.tasks = make(chan queue.MirrorTask)
if !config.Proxy.Enable || config.Proxy.URL == "" {
w.httpClient = &http.Client{}
} else {
proxyURL, err := url.Parse(config.Proxy.URL)
if err != nil {
return nil, fmt.Errorf("fail to parse proxy url: %w", err)
}
w.httpClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}
}
return w, nil
}

Expand Down Expand Up @@ -188,8 +202,7 @@ func (w *MinioLFSSyncWorker) GetLFSDownloadURLs(ctx context.Context, mirror *dat
req.Header.Set("Content-Type", "application/vnd.git-lfs+json; charset=utf-8")
req.Header.Set("User-Agent", "git-lfs/3.5.1")

client := &http.Client{}
resp, err := client.Do(req)
resp, err := w.httpClient.Do(req)
if err != nil {
return resPointers, fmt.Errorf("failed to send LFS batch request: %v", err)
}
Expand Down Expand Up @@ -282,8 +295,7 @@ func (w *MinioLFSSyncWorker) DownloadAndUploadLFSFile(ctx context.Context, mirro
req.Header.Set("Content-Type", "application/vnd.git-lfs+json; charset=utf-8")
req.Header.Set("User-Agent", "git-lfs/3.5.1")

client := &http.Client{}
resp, err := client.Do(req)
resp, err := w.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download LFS file: %w", err)
}
Expand Down

0 comments on commit 10a4559

Please sign in to comment.