Skip to content

Commit

Permalink
Fix: dial tcp with context to avoid margin of error
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamacro committed Oct 12, 2019
1 parent 0cdc40b commit 7c4a359
Show file tree
Hide file tree
Showing 14 changed files with 47 additions and 30 deletions.
10 changes: 8 additions & 2 deletions adapters/outbound/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,13 @@ func (p *Proxy) Alive() bool {
}

func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
conn, err := p.ProxyAdapter.Dial(metadata)
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel()
return p.DialContext(ctx, metadata)
}

func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
conn, err := p.ProxyAdapter.DialContext(ctx, metadata)
if err != nil {
p.alive = false
}
Expand Down Expand Up @@ -157,7 +163,7 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
}

start := time.Now()
instance, err := p.Dial(&addr)
instance, err := p.DialContext(ctx, &addr)
if err != nil {
return
}
Expand Down
5 changes: 3 additions & 2 deletions adapters/outbound/direct.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"net"

C "github.com/Dreamacro/clash/constant"
Expand All @@ -10,13 +11,13 @@ type Direct struct {
*Base
}

func (d *Direct) Dial(metadata *C.Metadata) (C.Conn, error) {
func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
address := net.JoinHostPort(metadata.Host, metadata.DstPort)
if metadata.DstIP != nil {
address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort)
}

c, err := dialTimeout("tcp", address, tcpTimeout)
c, err := dialContext(ctx, "tcp", address)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions adapters/outbound/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ func (f *Fallback) Now() string {
return proxy.Name()
}

func (f *Fallback) Dial(metadata *C.Metadata) (C.Conn, error) {
func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
proxy := f.findAliveProxy()
c, err := proxy.Dial(metadata)
c, err := proxy.DialContext(ctx, metadata)
if err == nil {
c.AppendToChains(f)
}
Expand Down
5 changes: 3 additions & 2 deletions adapters/outbound/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package adapters
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
Expand Down Expand Up @@ -35,8 +36,8 @@ type HttpOption struct {
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
}

func (h *Http) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", h.addr, tcpTimeout)
func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", h.addr)
if err == nil && h.tls {
cc := tls.Client(c, h.tlsConfig)
err = cc.Handshake()
Expand Down
6 changes: 3 additions & 3 deletions adapters/outbound/loadbalance.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func jumpHash(key uint64, buckets int32) int32 {
return int32(b)
}

func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) {
func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) {
defer func() {
if err == nil {
c.AppendToChains(lb)
Expand All @@ -67,11 +67,11 @@ func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) {
idx := jumpHash(key, buckets)
proxy := lb.proxies[idx]
if proxy.Alive() {
c, err = proxy.Dial(metadata)
c, err = proxy.DialContext(ctx, metadata)
return
}
}
c, err = lb.proxies[0].Dial(metadata)
c, err = lb.proxies[0].DialContext(ctx, metadata)
return
}

Expand Down
3 changes: 2 additions & 1 deletion adapters/outbound/reject.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"io"
"net"
"time"
Expand All @@ -12,7 +13,7 @@ type Reject struct {
*Base
}

func (r *Reject) Dial(metadata *C.Metadata) (C.Conn, error) {
func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
return newConn(&NopConn{}, r), nil
}

Expand Down
5 changes: 3 additions & 2 deletions adapters/outbound/selector.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"encoding/json"
"errors"
"net"
Expand All @@ -20,8 +21,8 @@ type SelectorOption struct {
Proxies []string `proxy:"proxies"`
}

func (s *Selector) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := s.selected.Dial(metadata)
func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := s.selected.DialContext(ctx, metadata)
if err == nil {
c.AppendToChains(s)
}
Expand Down
5 changes: 3 additions & 2 deletions adapters/outbound/shadowsocks.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -57,8 +58,8 @@ type v2rayObfsOption struct {
Mux bool `obfs:"mux,omitempty"`
}

func (ss *ShadowSocks) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", ss.server, tcpTimeout)
func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", ss.server)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", ss.server, err.Error())
}
Expand Down
5 changes: 3 additions & 2 deletions adapters/outbound/snell.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"fmt"
"net"
"strconv"
Expand All @@ -26,8 +27,8 @@ type SnellOption struct {
ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"`
}

func (s *Snell) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", s.server, tcpTimeout)
func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", s.server)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", s.server, err.Error())
}
Expand Down
9 changes: 6 additions & 3 deletions adapters/outbound/socks5.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"crypto/tls"
"fmt"
"io"
Expand Down Expand Up @@ -33,8 +34,8 @@ type Socks5Option struct {
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
}

func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", ss.addr, tcpTimeout)
func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", ss.addr)

if err == nil && ss.tls {
cc := tls.Client(c, ss.tlsConfig)
Expand All @@ -60,7 +61,9 @@ func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) {
}

func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err error) {
c, err := dialTimeout("tcp", ss.addr, tcpTimeout)
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel()
c, err := dialContext(ctx, "tcp", ss.addr)
if err != nil {
err = fmt.Errorf("%s connect error", ss.addr)
return
Expand Down
4 changes: 2 additions & 2 deletions adapters/outbound/urltest.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ func (u *URLTest) Now() string {
return u.fast.Name()
}

func (u *URLTest) Dial(metadata *C.Metadata) (c C.Conn, err error) {
func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) {
for i := 0; i < 3; i++ {
c, err = u.fast.Dial(metadata)
c, err = u.fast.DialContext(ctx, metadata)
if err == nil {
c.AppendToChains(u)
return
Expand Down
4 changes: 1 addition & 3 deletions adapters/outbound/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,13 @@ func serializesSocksAddr(metadata *C.Metadata) []byte {
return bytes.Join(buf, nil)
}

func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}

dialer := net.Dialer{}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

returned := make(chan struct{})
defer close(returned)
Expand Down
9 changes: 6 additions & 3 deletions adapters/outbound/vmess.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adapters

import (
"context"
"fmt"
"net"
"strconv"
Expand Down Expand Up @@ -31,8 +32,8 @@ type VmessOption struct {
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
}

func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", v.server, tcpTimeout)
func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", v.server)
if err != nil {
return nil, fmt.Errorf("%s connect error", v.server)
}
Expand All @@ -42,7 +43,9 @@ func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) {
}

func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) {
c, err := dialTimeout("tcp", v.server, tcpTimeout)
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel()
c, err := dialContext(ctx, "tcp", v.server)
if err != nil {
return nil, nil, fmt.Errorf("%s connect error", v.server)
}
Expand Down
3 changes: 2 additions & 1 deletion constant/adapters.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type PacketConn interface {
type ProxyAdapter interface {
Name() string
Type() AdapterType
Dial(metadata *Metadata) (Conn, error)
DialContext(ctx context.Context, metadata *Metadata) (Conn, error)
DialUDP(metadata *Metadata) (PacketConn, net.Addr, error)
SupportUDP() bool
Destroy()
Expand All @@ -74,6 +74,7 @@ type Proxy interface {
ProxyAdapter
Alive() bool
DelayHistory() []DelayHistory
Dial(metadata *Metadata) (Conn, error)
LastDelay() uint16
URLTest(ctx context.Context, url string) (uint16, error)
}
Expand Down

0 comments on commit 7c4a359

Please sign in to comment.