Skip to content

Commit

Permalink
fix issue ginuerzh#173
Browse files Browse the repository at this point in the history
  • Loading branch information
rui.zheng committed Nov 2, 2017
1 parent c82f2d9 commit e3120ca
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 66 deletions.
43 changes: 41 additions & 2 deletions chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package gost
import (
"errors"
"net"
"strings"

"github.com/go-log/log"
)

var (
Expand Down Expand Up @@ -122,13 +125,18 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
if selector == nil {
selector = &defaultSelector{}
}
// select node from node group
node, err := selector.Select(groups[0].Nodes(), groups[0].Options...)
if err != nil {
return
}
nodes = append(nodes, node)

cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
addr, err := selectIP(&node)
if err != nil {
return
}
cn, err := node.Client.Dial(addr, node.DialOptions...)
if err != nil {
return
}
Expand All @@ -154,8 +162,13 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
}
nodes = append(nodes, node)

addr, err = selectIP(&node)
if err != nil {
return
}

var cc net.Conn
cc, err = preNode.Client.Connect(cn, node.Addr)
cc, err = preNode.Client.Connect(cn, addr)
if err != nil {
cn.Close()
return
Expand All @@ -172,3 +185,29 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
conn = cn
return
}

func selectIP(node *Node) (string, error) {
addr := node.Addr
s := node.IPSelector
if s == nil {
s = &RandomIPSelector{}
}
// select IP from IP list
ip, err := s.Select(node.IPs)
if err != nil {
return "", err
}
if ip != "" {
if !strings.Contains(ip, ":") {
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return "", err
}
ip = ip + ":" + sport
}
addr = ip
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr))
}
log.Log("select IP:", node.Addr, node.IPs, addr)
return addr, nil
}
21 changes: 1 addition & 20 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"crypto/tls"
"net"
"net/url"
"sync/atomic"
"time"
)

Expand Down Expand Up @@ -64,7 +63,6 @@ type Transporter interface {
}

type tcpTransporter struct {
count uint64
}

// TCPTransporter creates a transporter for TCP proxy client.
Expand All @@ -78,16 +76,6 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er
option(opts)
}

if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}

if opts.Chain == nil {
return net.DialTimeout("tcp", addr, opts.Timeout)
}
Expand All @@ -106,7 +94,7 @@ func (tr *tcpTransporter) Multiplex() bool {
type DialOptions struct {
Timeout time.Duration
Chain *Chain
IPs []string
// IPs []string
}

// DialOption allows a common way to set dial options.
Expand All @@ -126,13 +114,6 @@ func ChainDialOption(chain *Chain) DialOption {
}
}

// IPDialOption specifies an IP list used by Transporter.Dial
func IPDialOption(ips ...string) DialOption {
return func(opts *DialOptions) {
opts.IPs = ips
}
}

// HandshakeOptions describes the options for handshake.
type HandshakeOptions struct {
Addr string
Expand Down
13 changes: 8 additions & 5 deletions cmd/gost/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ func initChain() (*gost.Chain, error) {
if err != nil {
return nil, err
}

node.IPs = parseIP(node.Values.Get("ip"))
node.IPSelector = &gost.RoundRobinIPSelector{}

users, err := parseUsers(node.Values.Get("secrets"))
if err != nil {
return nil, err
Expand Down Expand Up @@ -201,7 +205,6 @@ func initChain() (*gost.Chain, error) {
timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
node.DialOptions = append(node.DialOptions,
gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
gost.IPDialOption(parseIP(node.Values.Get("ip"))...),
)

interval, _ := strconv.Atoi(node.Values.Get("ping"))
Expand Down Expand Up @@ -511,9 +514,11 @@ func parseIP(s string) (ips []string) {
if err != nil {
ss := strings.Split(s, ",")
for _, s := range ss {
if ip := net.ParseIP(s); ip != nil {
s = strings.TrimSpace(s)
if s != "" {
ips = append(ips, s)
}

}
return
}
Expand All @@ -524,9 +529,7 @@ func parseIP(s string) (ips []string) {
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if ip := net.ParseIP(line); ip != nil {
ips = append(ips, line)
}
ips = append(ips, line)
}
return
}
4 changes: 3 additions & 1 deletion node.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
// Node is a proxy node, mainly used to construct a proxy chain.
type Node struct {
Addr string
IPs []string
Protocol string
Transport string
Remote string // remote address, used by tcp/udp port forwarding
Expand All @@ -16,6 +17,7 @@ type Node struct {
DialOptions []DialOption
HandshakeOptions []HandshakeOption
Client *Client
IPSelector IPSelector
}

// ParseNode parses the node info.
Expand Down Expand Up @@ -81,7 +83,7 @@ func ParseNode(s string) (node Node, err error) {
type NodeGroup struct {
nodes []Node
Options []SelectOption
Selector Selector
Selector NodeSelector
}

// NewNodeGroup creates a node group
Expand Down
51 changes: 48 additions & 3 deletions selector.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package gost

import "errors"
import (
"errors"
"sync/atomic"
"time"
)

var (
// ErrNoneAvailable indicates there is no node available
Expand All @@ -10,8 +14,8 @@ var (
// SelectOption used when making a select call
type SelectOption func(*SelectOptions)

// Selector as a mechanism to pick nodes and mark their status.
type Selector interface {
// NodeSelector as a mechanism to pick nodes and mark their status.
type NodeSelector interface {
Select(nodes []Node, opts ...SelectOption) (Node, error)
// Mark(node Node)
String() string
Expand Down Expand Up @@ -71,3 +75,44 @@ func WithStrategy(s Strategy) SelectOption {
o.Strategy = s
}
}

// IPSelector as a mechanism to pick IPs and mark their status.
type IPSelector interface {
Select(ips []string) (string, error)
String() string
}

// RandomIPSelector is an IP Selector that selects an IP with random strategy.
type RandomIPSelector struct {
}

// Select selects an IP from ips list.
func (s *RandomIPSelector) Select(ips []string) (string, error) {
if len(ips) == 0 {
return "", nil
}
return ips[time.Now().Nanosecond()%len(ips)], nil
}

func (s *RandomIPSelector) String() string {
return "random"
}

// RoundRobinIPSelector is an IP Selector that selects an IP with round-robin strategy.
type RoundRobinIPSelector struct {
count uint64
}

// Select selects an IP from ips list.
func (s *RoundRobinIPSelector) Select(ips []string) (string, error) {
if len(ips) == 0 {
return "", nil
}

count := atomic.AddUint64(&s.count, 1)
return ips[int(count%uint64(len(ips)))], nil
}

func (s *RoundRobinIPSelector) String() string {
return "round"
}
13 changes: 1 addition & 12 deletions tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"net"
"sync"
"sync/atomic"
"time"

"github.com/go-log/log"
Expand Down Expand Up @@ -53,20 +52,10 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
option(opts)
}

if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}

tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

session, ok := tr.sessions[addr] // TODO: the addr may be changed.
session, ok := tr.sessions[addr]
if !ok {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
Expand Down
25 changes: 2 additions & 23 deletions ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"net/http"
"net/http/httputil"
"sync"
"sync/atomic"
"time"

"net/url"
Expand Down Expand Up @@ -155,20 +154,10 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con
option(opts)
}

if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}

tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

session, ok := tr.sessions[addr] // TODO: the addr may be changed.
session, ok := tr.sessions[addr]
if !ok {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
Expand Down Expand Up @@ -288,20 +277,10 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co
option(opts)
}

if len(opts.IPs) > 0 {
count := atomic.AddUint64(&tr.count, 1)
_, sport, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
n := uint64(len(opts.IPs))
addr = opts.IPs[int(count%n)] + ":" + sport
}

tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

session, ok := tr.sessions[addr] // TODO: the addr may be changed.
session, ok := tr.sessions[addr]
if !ok {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
Expand Down

0 comments on commit e3120ca

Please sign in to comment.