Skip to content
This repository has been archived by the owner on Nov 4, 2023. It is now read-only.

Commit

Permalink
Optimization: refactor picker
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamacro committed Jul 2, 2019
1 parent 0eff851 commit 7c6c147
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 104 deletions.
13 changes: 11 additions & 2 deletions adapters/outbound/base.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"encoding/json"
"errors"
"net"
Expand Down Expand Up @@ -99,7 +100,7 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
}

// URLTest get the delay for the specified URL
func (p *Proxy) URLTest(url string) (t uint16, err error) {
func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
defer func() {
p.alive = err == nil
record := C.DelayHistory{Time: time.Now()}
Expand All @@ -123,6 +124,13 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) {
return
}
defer instance.Close()

req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return
}
req = req.WithContext(ctx)

transport := &http.Transport{
Dial: func(string, string) (net.Conn, error) {
return instance, nil
Expand All @@ -133,8 +141,9 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) {
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}

client := http.Client{Transport: transport}
resp, err := client.Get(url)
resp, err := client.Do(req)
if err != nil {
return
}
Expand Down
3 changes: 2 additions & 1 deletion adapters/outbound/fallback.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"encoding/json"
"errors"
"net"
Expand Down Expand Up @@ -90,7 +91,7 @@ func (f *Fallback) validTest() {

for _, p := range f.proxies {
go func(p C.Proxy) {
p.URLTest(f.rawURL)
p.URLTest(context.Background(), f.rawURL)
wg.Done()
}(p)
}
Expand Down
3 changes: 2 additions & 1 deletion adapters/outbound/loadbalance.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"encoding/json"
"errors"
"net"
Expand Down Expand Up @@ -95,7 +96,7 @@ func (lb *LoadBalance) validTest() {

for _, p := range lb.proxies {
go func(p C.Proxy) {
p.URLTest(lb.rawURL)
p.URLTest(context.Background(), lb.rawURL)
wg.Done()
}(p)
}
Expand Down
38 changes: 12 additions & 26 deletions adapters/outbound/urltest.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"net"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -103,35 +102,22 @@ func (u *URLTest) speedTest() {
}
defer atomic.StoreInt32(&u.once, 0)

wg := sync.WaitGroup{}
wg.Add(len(u.proxies))
c := make(chan interface{})
fast := picker.SelectFast(context.Background(), c)
timer := time.NewTimer(u.interval)

ctx, cancel := context.WithTimeout(context.Background(), u.interval)
defer cancel()
picker, ctx := picker.WithContext(ctx)
for _, p := range u.proxies {
go func(p C.Proxy) {
_, err := p.URLTest(u.rawURL)
if err == nil {
c <- p
picker.Go(func() (interface{}, error) {
_, err := p.URLTest(ctx, u.rawURL)
if err != nil {
return nil, err
}
wg.Done()
}(p)
return p, nil
})
}

go func() {
wg.Wait()
close(c)
}()

select {
case <-timer.C:
// Wait for fast to return or close.
<-fast
case p, open := <-fast:
if open {
u.fast = p.(C.Proxy)
}
fast := picker.Wait()
if fast != nil {
u.fast = fast.(C.Proxy)
}
}

Expand Down
59 changes: 45 additions & 14 deletions common/picker/picker.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,53 @@
package picker

import "context"
import (
"context"
"sync"
)

// Picker provides synchronization, and Context cancelation
// for groups of goroutines working on subtasks of a common task.
// Inspired by errGroup
type Picker struct {
cancel func()

wg sync.WaitGroup

once sync.Once
result interface{}
}

// WithContext returns a new Picker and an associated Context derived from ctx.
func WithContext(ctx context.Context) (*Picker, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &Picker{cancel: cancel}, ctx
}

// Wait blocks until all function calls from the Go method have returned,
// then returns the first nil error result (if any) from them.
func (p *Picker) Wait() interface{} {
p.wg.Wait()
if p.cancel != nil {
p.cancel()
}
return p.result
}

// Go calls the given function in a new goroutine.
// The first call to return a nil error cancels the group; its result will be returned by Wait.
func (p *Picker) Go(f func() (interface{}, error)) {
p.wg.Add(1)

func SelectFast(ctx context.Context, in <-chan interface{}) <-chan interface{} {
out := make(chan interface{})
go func() {
select {
case p, open := <-in:
if open {
out <- p
}
case <-ctx.Done():
}
defer p.wg.Done()

close(out)
for range in {
if ret, err := f(); err == nil {
p.once.Do(func() {
p.result = ret
if p.cancel != nil {
p.cancel()
}
})
}
}()

return out
}
44 changes: 21 additions & 23 deletions common/picker/picker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,37 @@ import (
"time"
)

func sleepAndSend(delay int, in chan<- interface{}, input interface{}) {
time.Sleep(time.Millisecond * time.Duration(delay))
in <- input
}

func sleepAndClose(delay int, in chan interface{}) {
time.Sleep(time.Millisecond * time.Duration(delay))
close(in)
func sleepAndSend(ctx context.Context, delay int, input interface{}) func() (interface{}, error) {
return func() (interface{}, error) {
timer := time.NewTimer(time.Millisecond * time.Duration(delay))
select {
case <-timer.C:
return input, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
}

func TestPicker_Basic(t *testing.T) {
in := make(chan interface{})
fast := SelectFast(context.Background(), in)
go sleepAndSend(20, in, 1)
go sleepAndSend(30, in, 2)
go sleepAndClose(40, in)
picker, ctx := WithContext(context.Background())
picker.Go(sleepAndSend(ctx, 30, 2))
picker.Go(sleepAndSend(ctx, 20, 1))

number, exist := <-fast
if !exist || number != 1 {
t.Error("should recv 1", exist, number)
number := picker.Wait()
if number != nil && number.(int) != 1 {
t.Error("should recv 1", number)
}
}

func TestPicker_Timeout(t *testing.T) {
in := make(chan interface{})
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*5)
defer cancel()
fast := SelectFast(ctx, in)
go sleepAndSend(20, in, 1)
go sleepAndClose(30, in)
picker, ctx := WithContext(ctx)
picker.Go(sleepAndSend(ctx, 20, 1))

_, exist := <-fast
if exist {
t.Error("should recv false")
number := picker.Wait()
if number != nil {
t.Error("should recv nil")
}
}
3 changes: 2 additions & 1 deletion constant/adapters.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package constant

import (
"context"
"net"
"time"
)
Expand Down Expand Up @@ -44,7 +45,7 @@ type Proxy interface {
Alive() bool
DelayHistory() []DelayHistory
LastDelay() uint16
URLTest(url string) (uint16, error)
URLTest(ctx context.Context, url string) (uint16, error)
}

// AdapterType is enum of adapter type
Expand Down
24 changes: 7 additions & 17 deletions dns/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,32 +163,22 @@ func (r *Resolver) IsFakeIP() bool {
}

func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) {
in := make(chan interface{})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
fast := picker.SelectFast(ctx, in)
fast, ctx := picker.WithContext(ctx)

wg := sync.WaitGroup{}
wg.Add(len(clients))
for _, r := range clients {
go func(r resolver) {
defer wg.Done()
fast.Go(func() (interface{}, error) {
msg, err := r.ExchangeContext(ctx, m)
if err != nil || msg.Rcode != D.RcodeSuccess {
return
return nil, errors.New("resolve error")
}
in <- msg
}(r)
return msg, nil
})
}

// release in channel
go func() {
wg.Wait()
close(in)
}()

elm, exist := <-fast
if !exist {
elm := fast.Wait()
if elm == nil {
return nil, errors.New("All DNS requests failed")
}

Expand Down
40 changes: 21 additions & 19 deletions hub/route/proxies.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

A "github.com/Dreamacro/clash/adapters/outbound"
"github.com/Dreamacro/clash/common/picker"
C "github.com/Dreamacro/clash/constant"
T "github.com/Dreamacro/clash/tunnel"

Expand Down Expand Up @@ -110,27 +111,28 @@ func getProxyDelay(w http.ResponseWriter, r *http.Request) {

proxy := r.Context().Value(CtxKeyProxy).(C.Proxy)

sigCh := make(chan uint16)
go func() {
t, err := proxy.URLTest(url)
if err != nil {
sigCh <- 0
}
sigCh <- t
}()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(timeout))
defer cancel()
picker, ctx := picker.WithContext(ctx)
picker.Go(func() (interface{}, error) {
return proxy.URLTest(ctx, url)
})

select {
case <-time.After(time.Millisecond * time.Duration(timeout)):
elm := picker.Wait()
if elm == nil {
render.Status(r, http.StatusRequestTimeout)
render.JSON(w, r, ErrRequestTimeout)
case t := <-sigCh:
if t == 0 {
render.Status(r, http.StatusServiceUnavailable)
render.JSON(w, r, newError("An error occurred in the delay test"))
} else {
render.JSON(w, r, render.M{
"delay": t,
})
}
return
}

delay := elm.(uint16)
if delay == 0 {
render.Status(r, http.StatusServiceUnavailable)
render.JSON(w, r, newError("An error occurred in the delay test"))
return
}

render.JSON(w, r, render.M{
"delay": delay,
})
}

0 comments on commit 7c6c147

Please sign in to comment.