diff --git a/client/control.go b/client/control.go index 067fe37f581..b563be44618 100644 --- a/client/control.go +++ b/client/control.go @@ -234,8 +234,11 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) { } } - address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)) - conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig, ctl.clientCfg.DisableCustomTLSFirstByte) + conn, err = frpNet.DialWithOptions(net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)), + frpNet.WithProxyURL(ctl.clientCfg.HTTPProxy), + frpNet.WithProtocol(ctl.clientCfg.Protocol), + frpNet.WithTLSConfig(tlsConfig), + frpNet.WithDisableCustomTLSHeadByte(ctl.clientCfg.DisableCustomTLSFirstByte)) if err != nil { xl.Warn("start new connection to server error: %v", err) diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index 47ab03ca42f..c535df5989d 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -790,7 +790,7 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf return } - localConn, err := frpNet.ConnectServer("tcp", fmt.Sprintf("%s:%d", localInfo.LocalIP, localInfo.LocalPort)) + localConn, err := frpNet.DialWithOptions(net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort))) if err != nil { workConn.Close() xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err) diff --git a/client/service.go b/client/service.go index 8b88003477a..d0f38453181 100644 --- a/client/service.go +++ b/client/service.go @@ -228,8 +228,12 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) { } } - address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)) - conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig, svr.cfg.DisableCustomTLSFirstByte) + conn, err = frpNet.DialWithOptions(net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)), + frpNet.WithProxyURL(svr.cfg.HTTPProxy), + frpNet.WithProtocol(svr.cfg.Protocol), + frpNet.WithTLSConfig(tlsConfig), + frpNet.WithDisableCustomTLSHeadByte(svr.cfg.DisableCustomTLSFirstByte)) + if err != nil { return } diff --git a/pkg/util/net/conn.go b/pkg/util/net/conn.go index ccb199e5d7c..366357d023e 100644 --- a/pkg/util/net/conn.go +++ b/pkg/util/net/conn.go @@ -16,15 +16,16 @@ package net import ( "context" - "crypto/tls" "errors" "fmt" "io" "net" + "net/url" "sync/atomic" "time" "github.com/fatedier/frp/pkg/util/xlog" + "golang.org/x/net/websocket" gnet "github.com/fatedier/golib/net" kcp "github.com/fatedier/kcp-go" @@ -194,50 +195,61 @@ func ConnectServer(protocol string, addr string) (c net.Conn, err error) { case "tcp": return net.Dial("tcp", addr) case "kcp": - kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3) - if errRet != nil { - err = errRet - return - } - kcpConn.SetStreamMode(true) - kcpConn.SetWriteDelay(true) - kcpConn.SetNoDelay(1, 20, 2, 1) - kcpConn.SetWindowSize(128, 512) - kcpConn.SetMtu(1350) - kcpConn.SetACKNoDelay(false) - kcpConn.SetReadBuffer(4194304) - kcpConn.SetWriteBuffer(4194304) - c = kcpConn - return + return DialKCPServer(addr) + case "websocket": + return DialWebsocketServer(addr) default: return nil, fmt.Errorf("unsupport protocol: %s", protocol) } } +func DialKCPServer(addr string) (c net.Conn, err error) { + kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3) + if errRet != nil { + err = errRet + return + } + kcpConn.SetStreamMode(true) + kcpConn.SetWriteDelay(true) + kcpConn.SetNoDelay(1, 20, 2, 1) + kcpConn.SetWindowSize(128, 512) + kcpConn.SetMtu(1350) + kcpConn.SetACKNoDelay(false) + kcpConn.SetReadBuffer(4194304) + kcpConn.SetWriteBuffer(4194304) + c = kcpConn + return +} + func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) { switch protocol { case "tcp": return gnet.DialTcpByProxy(proxyURL, addr) - case "kcp": - // http proxy is not supported for kcp - return ConnectServer(protocol, addr) - case "websocket": - return ConnectWebsocketServer(addr) default: - return nil, fmt.Errorf("unsupport protocol: %s", protocol) + return nil, fmt.Errorf("unsupport protocol: %s when connecting by proxy", protocol) } } -func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (c net.Conn, err error) { - c, err = ConnectServerByProxy(proxyURL, protocol, addr) +// addr: domain:port +func DialWebsocketServer(addr string) (net.Conn, error) { + addr = "ws://" + addr + FrpWebsocketPath + uri, err := url.Parse(addr) if err != nil { - return + return nil, err } - if tlsConfig == nil { - return + origin := "http://" + uri.Host + cfg, err := websocket.NewConfig(addr, origin) + if err != nil { + return nil, err + } + cfg.Dialer = &net.Dialer{ + Timeout: 10 * time.Second, } - c = WrapTLSClientConn(c, tlsConfig, disableCustomTLSHeadByte) - return + conn, err := websocket.DialConfig(cfg) + if err != nil { + return nil, err + } + return conn, nil } diff --git a/pkg/util/net/dial.go b/pkg/util/net/dial.go new file mode 100644 index 00000000000..2549821d0d6 --- /dev/null +++ b/pkg/util/net/dial.go @@ -0,0 +1,89 @@ +package net + +import ( + "crypto/tls" + "net" +) + +type dialOptions struct { + proxyURL string + protocol string + tlsConfig *tls.Config + disableCustomTLSHeadByte bool +} + +type DialOption interface { + apply(*dialOptions) +} + +type EmptyDialOption struct{} + +func (EmptyDialOption) apply(*dialOptions) {} + +type funcDialOption struct { + f func(*dialOptions) +} + +func (fdo *funcDialOption) apply(do *dialOptions) { + fdo.f(do) +} + +func newFuncDialOption(f func(*dialOptions)) *funcDialOption { + return &funcDialOption{ + f: f, + } +} + +func DefaultDialOptions() dialOptions { + return dialOptions{ + protocol: "tcp", + } +} + +func WithProxyURL(proxyURL string) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.proxyURL = proxyURL + }) +} + +func WithTLSConfig(tlsConfig *tls.Config) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.tlsConfig = tlsConfig + }) +} + +func WithDisableCustomTLSHeadByte(disableCustomTLSHeadByte bool) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.disableCustomTLSHeadByte = disableCustomTLSHeadByte + }) +} + +func WithProtocol(protocol string) DialOption { + return newFuncDialOption(func(do *dialOptions) { + do.protocol = protocol + }) +} + +func DialWithOptions(addr string, opts ...DialOption) (c net.Conn, err error) { + op := DefaultDialOptions() + + for _, opt := range opts { + opt.apply(&op) + } + + if op.proxyURL == "" { + c, err = ConnectServer(op.protocol, addr) + } else { + c, err = ConnectServerByProxy(op.proxyURL, op.protocol, addr) + } + if err != nil { + return nil, err + } + + if op.tlsConfig == nil { + return + } + + c = WrapTLSClientConn(c, op.tlsConfig, op.disableCustomTLSHeadByte) + return +} diff --git a/pkg/util/net/websocket.go b/pkg/util/net/websocket.go index 36b6440c5b5..7030787e700 100644 --- a/pkg/util/net/websocket.go +++ b/pkg/util/net/websocket.go @@ -5,8 +5,6 @@ import ( "fmt" "net" "net/http" - "net/url" - "time" "golang.org/x/net/websocket" ) @@ -77,27 +75,3 @@ func (p *WebsocketListener) Close() error { func (p *WebsocketListener) Addr() net.Addr { return p.ln.Addr() } - -// addr: domain:port -func ConnectWebsocketServer(addr string) (net.Conn, error) { - addr = "ws://" + addr + FrpWebsocketPath - uri, err := url.Parse(addr) - if err != nil { - return nil, err - } - - origin := "http://" + uri.Host - cfg, err := websocket.NewConfig(addr, origin) - if err != nil { - return nil, err - } - cfg.Dialer = &net.Dialer{ - Timeout: 10 * time.Second, - } - - conn, err := websocket.DialConfig(cfg) - if err != nil { - return nil, err - } - return conn, nil -}