Skip to content

Commit f1652cd

Browse files
committed
支持负载均衡下载
1 parent d1e936d commit f1652cd

File tree

6 files changed

+216
-66
lines changed

6 files changed

+216
-66
lines changed

internal/pcscommand/download.go

+57-45
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"io"
1414
"net/http"
1515
"net/http/cookiejar"
16+
"net/url"
1617
"os"
1718
"path/filepath"
1819
"strconv"
@@ -47,7 +48,7 @@ type DownloadOption struct {
4748
Parallel int
4849
}
4950

50-
func download(id int, downloadURL, savePath string, client *requester.HTTPClient, cfg *downloader.Config, isPrintStatus, isExecutedPermission bool) error {
51+
func download(id int, downloadURL, savePath string, loadBalansers []string, client *requester.HTTPClient, cfg *downloader.Config, isPrintStatus, isExecutedPermission bool) error {
5152
var (
5253
file *os.File
5354
writerAt io.WriterAt
@@ -87,10 +88,25 @@ func download(id int, downloadURL, savePath string, client *requester.HTTPClient
8788
download := downloader.NewDownloader(downloadURL, writerAt, cfg)
8889
download.SetClient(client)
8990
download.TryHTTP(!pcsconfig.Config.EnableHTTPS())
91+
download.AddLoadBalanceServer(loadBalansers...)
9092

9193
exitChan = make(chan struct{})
9294

9395
download.OnExecute(func() {
96+
if isPrintStatus {
97+
go func() {
98+
for {
99+
time.Sleep(1 * time.Second)
100+
select {
101+
case <-exitChan:
102+
return
103+
default:
104+
download.PrintAllWorkers()
105+
}
106+
}
107+
}()
108+
}
109+
94110
if cfg.IsTest {
95111
fmt.Printf("[%d] 测试下载开始\n\n", id)
96112
}
@@ -126,19 +142,6 @@ func download(id int, downloadURL, savePath string, client *requester.HTTPClient
126142
}
127143
})
128144

129-
if isPrintStatus {
130-
go func() {
131-
for {
132-
time.Sleep(1 * time.Second)
133-
select {
134-
case <-exitChan:
135-
return
136-
default:
137-
download.PrintAllWorkers()
138-
}
139-
}
140-
}()
141-
}
142145
err = download.Execute()
143146
close(exitChan)
144147
if err != nil {
@@ -161,27 +164,6 @@ func download(id int, downloadURL, savePath string, client *requester.HTTPClient
161164
return nil
162165
}
163166

164-
func getDownloadFunc(id int, savePath string, cfg *downloader.Config, isPrintStatus, isExecutedPermission bool) baidupcs.DownloadFunc {
165-
if cfg == nil {
166-
cfg = downloader.NewConfig()
167-
}
168-
169-
return func(downloadURL string, jar *cookiejar.Jar) error {
170-
h := requester.NewHTTPClient()
171-
h.SetCookiejar(jar)
172-
h.SetKeepAlive(true)
173-
h.SetTimeout(10 * time.Minute)
174-
setupHTTPClient(h)
175-
176-
err := download(id, downloadURL, savePath, h, cfg, isPrintStatus, isExecutedPermission)
177-
if err != nil {
178-
return err
179-
}
180-
181-
return nil
182-
}
183-
}
184-
185167
// RunDownload 执行下载网盘内文件
186168
func RunDownload(paths []string, option DownloadOption) {
187169
// 设置下载配置
@@ -331,10 +313,27 @@ func RunDownload(paths []string, option DownloadOption) {
331313
}
332314

333315
// 获取直链, 或者以分享文件的方式获取下载链接来下载
334-
var dlink string
316+
var (
317+
dlink string
318+
dlinks []string
319+
)
320+
335321
if option.IsLocateDownload {
336-
dlink = getDownloadLink(task.path)
322+
// 提取直链下载
323+
rawDlinks := getDownloadLinks(task.path)
324+
if len(rawDlinks) > 0 {
325+
dlink = rawDlinks[0].String()
326+
dlinks = make([]string, 0, len(rawDlinks)-1)
327+
for _, rawDlink := range rawDlinks[1:len(rawDlinks)] {
328+
if rawDlink == nil {
329+
continue
330+
}
331+
332+
dlinks = append(dlinks, rawDlink.String())
333+
}
334+
}
337335
} else if option.IsShareDownload {
336+
// 分享下载
338337
dlink = getShareDLink(task.path)
339338
}
340339

@@ -354,9 +353,22 @@ func RunDownload(paths []string, option DownloadOption) {
354353
client.SetTimeout(20 * time.Minute)
355354
client.SetKeepAlive(true)
356355
setupHTTPClient(client)
357-
err = download(task.ID, dlink, task.savePath, client, cfg, option.IsPrintStatus, option.IsExecutedPermission)
356+
err = download(task.ID, dlink, task.savePath, dlinks, client, cfg, option.IsPrintStatus, option.IsExecutedPermission)
358357
} else {
359-
err = pcs.DownloadFile(task.path, getDownloadFunc(task.ID, task.savePath, cfg, option.IsPrintStatus, option.IsExecutedPermission))
358+
err = pcs.DownloadFile(task.path, func(downloadURL string, jar *cookiejar.Jar) error {
359+
h := requester.NewHTTPClient()
360+
h.SetCookiejar(jar)
361+
h.SetKeepAlive(true)
362+
h.SetTimeout(10 * time.Minute)
363+
setupHTTPClient(h)
364+
365+
err := download(task.ID, downloadURL, task.savePath, dlinks, h, cfg, option.IsPrintStatus, option.IsExecutedPermission)
366+
if err != nil {
367+
return err
368+
}
369+
370+
return nil
371+
})
360372
}
361373

362374
if err != nil {
@@ -398,21 +410,21 @@ func RunLocateDownload(pcspaths ...string) {
398410
}
399411
}
400412

401-
func getDownloadLink(pcspath string) string {
413+
func getDownloadLinks(pcspath string) (dlinks []*url.URL) {
402414
pcs := GetBaiduPCS()
403415
dInfo, pcsError := pcs.LocateDownload(pcspath)
404416
if pcsError != nil {
405417
pcsCommandVerbose.Warn(pcsError.Error())
406-
return ""
418+
return
407419
}
408420

409-
u := dInfo.SingleURL(pcsconfig.Config.EnableHTTPS())
410-
if u == nil {
421+
us := dInfo.URLStrings(pcsconfig.Config.EnableHTTPS())
422+
if len(us) == 0 {
411423
pcsCommandVerbose.Warn("no any url")
412-
return ""
424+
return
413425
}
414426

415-
return u.String()
427+
return us
416428
}
417429

418430
// fileExist 检查文件是否存在,

requester/downloader/download_test.go

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ var (
1414
url2 = "https://git.oschina.net/lufenping/pixabay_img/raw/master/tiny-20170712/lizard-2427248_1920.jpg"
1515
)
1616

17+
func TestRandomNumber(t *testing.T) {
18+
for i := 0; i < 10; i++ {
19+
fmt.Println(RandomNumber(0, 5))
20+
}
21+
}
22+
1723
func TestExample(t *testing.T) {
1824
DoDownload(url2, "lizard-2427248_1920.jpg", nil)
1925
}

requester/downloader/downloader.go

+65-18
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ import (
55
"context"
66
"errors"
77
"fmt"
8+
"github.com/iikira/BaiduPCS-Go/pcsutil/waitgroup"
89
"github.com/iikira/BaiduPCS-Go/pcsverbose"
910
"github.com/iikira/BaiduPCS-Go/requester"
1011
"io"
12+
"net/http"
1113
"sync"
1214
"time"
1315
)
@@ -28,6 +30,7 @@ type Downloader struct {
2830
executeTime time.Time
2931
executed bool
3032
durl string
33+
loadBalansers []string
3134
tryHTTP bool
3235
writer io.WriterAt
3336
client *requester.HTTPClient
@@ -103,19 +106,62 @@ func (der *Downloader) Execute() error {
103106
status.totalSize = resp.ContentLength
104107

105108
var (
106-
req = resp.Request
107-
durl, referer string
108-
)
109+
loadBalancerResponses = make([]*LoadBalancerResponse, 0, len(der.loadBalansers)+1)
110+
handleLoadBalancer = func(req *http.Request) {
111+
if req != nil {
112+
if der.tryHTTP {
113+
req.URL.Scheme = "http"
114+
}
109115

110-
if req != nil {
111-
referer = req.Referer()
116+
loadBalancer := &LoadBalancerResponse{
117+
URL: req.URL.String(),
118+
Referer: req.Referer(),
119+
}
112120

113-
if der.tryHTTP {
114-
req.URL.Scheme = "http"
121+
loadBalancerResponses = append(loadBalancerResponses, loadBalancer)
122+
pcsverbose.Verbosef("DEBUG: download task: URL: %s, Referer: %s\n", loadBalancer.URL, loadBalancer.Referer)
123+
}
115124
}
116-
durl = req.URL.String()
117-
pcsverbose.Verbosef("DEBUG: download task: URL: %s, Referer: %s\n", durl, referer)
125+
)
126+
127+
handleLoadBalancer(resp.Request)
128+
129+
// 负载均衡
130+
wg := waitgroup.NewWaitGroup(10)
131+
privTimeout := der.client.Client.Timeout
132+
der.client.SetTimeout(5 * time.Second)
133+
for _, loadBalanser := range der.loadBalansers {
134+
wg.AddDelta()
135+
go func(loadBalanser string) {
136+
defer wg.Done()
137+
138+
subResp, subErr := der.client.Req("HEAD", loadBalanser, nil, nil)
139+
if subResp != nil {
140+
defer subResp.Body.Close()
141+
}
142+
if subErr != nil {
143+
pcsverbose.Verbosef("DEBUG: loadBalanser Error: %s\n", subErr)
144+
return
145+
}
146+
147+
if !ServerEqual(resp, subResp) {
148+
pcsverbose.Verbosef("DEBUG: loadBalanser not equal to main server: %s\n", subErr)
149+
return
150+
}
151+
152+
if subResp.Request != nil {
153+
loadBalancerResponses = append(loadBalancerResponses, &LoadBalancerResponse{
154+
URL: subResp.Request.URL.String(),
155+
})
156+
}
157+
handleLoadBalancer(subResp.Request)
158+
159+
}(loadBalanser)
118160
}
161+
wg.Wait()
162+
der.client.SetTimeout(privTimeout)
163+
164+
loadBalancerResponseList := NewLoadBalancerResponseList(loadBalancerResponses)
119165

120166
//load breakpoint
121167
err = der.initInstanceState()
@@ -178,16 +224,17 @@ func (der *Downloader) Execute() error {
178224
writerAt = der.writer
179225
}
180226

181-
workerInit := func(wer *Worker) {
182-
wer.SetClient(der.client)
183-
wer.SetCacheSize(der.config.cacheSize)
184-
wer.SetWriteMutex(writeMu)
185-
wer.SetReferer(referer)
186-
}
187-
188227
for i := 0; i < der.config.parallel; i++ {
189-
worker := NewWorker(int32(i), durl, writerAt)
190-
workerInit(worker)
228+
loadBalancer := loadBalancerResponseList.SequentialGet()
229+
if loadBalancer == nil {
230+
continue
231+
}
232+
233+
worker := NewWorker(i, loadBalancer.URL, writerAt)
234+
worker.SetClient(der.client)
235+
worker.SetCacheSize(der.config.cacheSize)
236+
worker.SetWriteMutex(writeMu)
237+
worker.SetReferer(loadBalancer.Referer)
191238

192239
// 分配线程
193240
if isRange {

requester/downloader/loadbalance.go

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package downloader
2+
3+
import (
4+
"net/http"
5+
"sync/atomic"
6+
)
7+
8+
// LoadBalancerResponse describes one download endpoint: the URL to fetch from
// and the Referer value to send along with requests to it.
type LoadBalancerResponse struct {
	URL     string
	Referer string
}

// LoadBalancerResponseList is a set of endpoints that can be handed out in
// rotation.
type LoadBalancerResponseList struct {
	lbr []*LoadBalancerResponse
	// cursor counts how many times SequentialGet has been called. It is only
	// ever touched through sync/atomic.
	cursor int32
}

// NewLoadBalancerResponseList builds a list over lbr. The slice is stored as
// given (not copied); a nil or empty slice yields a list whose SequentialGet
// always returns nil.
func NewLoadBalancerResponseList(lbr []*LoadBalancerResponse) *LoadBalancerResponseList {
	return &LoadBalancerResponseList{
		lbr: lbr,
	}
}

// SequentialGet returns the endpoints in round-robin order: the first call
// yields element 0, the next element 1, and so on, wrapping back to element 0
// after the last one. It returns nil when the list is empty.
//
// The slot is claimed with a single atomic increment, so concurrent callers
// each receive a distinct position and can never index past the end of the
// list.
func (lbrl *LoadBalancerResponseList) SequentialGet() *LoadBalancerResponse {
	n := len(lbrl.lbr)
	if n == 0 {
		return nil
	}

	// AddInt32 returns the value after the increment; subtract one to get the
	// slot this call owns.
	ticket := atomic.AddInt32(&lbrl.cursor, 1) - 1

	// Go through uint32 so the index stays non-negative even after the int32
	// counter wraps around.
	return lbrl.lbr[uint32(ticket)%uint32(n)]
}
41+
42+
// RandomGet 随机获取
43+
func (lbrl *LoadBalancerResponseList) RandomGet() *LoadBalancerResponse {
44+
return lbrl.lbr[RandomNumber(0, len(lbrl.lbr))]
45+
}
46+
47+
// AddLoadBalanceServer 增加负载均衡服务器
48+
func (der *Downloader) AddLoadBalanceServer(urls ...string) {
49+
der.loadBalansers = append(der.loadBalansers, urls...)
50+
}
51+
52+
// ServerEqual reports whether two responses look like they describe the same
// resource. Both responses must be non-nil, have the same ContentLength, and
// carry identical values for the Content-MD5, Content-Type and
// x-bs-meta-crc32 headers. A header missing on both sides compares as equal.
func ServerEqual(resp, subResp *http.Response) bool {
	switch {
	case resp == nil, subResp == nil:
		return false
	case resp.ContentLength != subResp.ContentLength:
		return false
	}

	// Every header listed here must match exactly between the two responses.
	for _, key := range [...]string{"Content-MD5", "Content-Type", "x-bs-meta-crc32"} {
		if resp.Header.Get(key) != subResp.Header.Get(key) {
			return false
		}
	}
	return true
}

0 commit comments

Comments
 (0)