Skip to content

Commit

Permalink
refactor: refine pkg net utils (fatedier#2720)
Browse files Browse the repository at this point in the history
* refactor: refine pkg net utils

* fix: x

Co-authored-by: blizard863 <[email protected]>
  • Loading branch information
bingtianbaihua and detry863 authored Dec 28, 2021
1 parent 0fb6aee commit ea568e8
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 60 deletions.
7 changes: 5 additions & 2 deletions client/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,11 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
}
}

address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort))
conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig, ctl.clientCfg.DisableCustomTLSFirstByte)
conn, err = frpNet.DialWithOptions(net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)),
frpNet.WithProxyURL(ctl.clientCfg.HTTPProxy),
frpNet.WithProtocol(ctl.clientCfg.Protocol),
frpNet.WithTLSConfig(tlsConfig),
frpNet.WithDisableCustomTLSHeadByte(ctl.clientCfg.DisableCustomTLSFirstByte))

if err != nil {
xl.Warn("start new connection to server error: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion client/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
return
}

localConn, err := frpNet.ConnectServer("tcp", fmt.Sprintf("%s:%d", localInfo.LocalIP, localInfo.LocalPort))
localConn, err := frpNet.DialWithOptions(net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort)))
if err != nil {
workConn.Close()
xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err)
Expand Down
8 changes: 6 additions & 2 deletions client/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,12 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
}
}

address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort))
conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig, svr.cfg.DisableCustomTLSFirstByte)
conn, err = frpNet.DialWithOptions(net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)),
frpNet.WithProxyURL(svr.cfg.HTTPProxy),
frpNet.WithProtocol(svr.cfg.Protocol),
frpNet.WithTLSConfig(tlsConfig),
frpNet.WithDisableCustomTLSHeadByte(svr.cfg.DisableCustomTLSFirstByte))

if err != nil {
return
}
Expand Down
70 changes: 41 additions & 29 deletions pkg/util/net/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ package net

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/url"
"sync/atomic"
"time"

"github.com/fatedier/frp/pkg/util/xlog"
"golang.org/x/net/websocket"

gnet "github.com/fatedier/golib/net"
kcp "github.com/fatedier/kcp-go"
Expand Down Expand Up @@ -194,50 +195,61 @@ func ConnectServer(protocol string, addr string) (c net.Conn, err error) {
case "tcp":
return net.Dial("tcp", addr)
case "kcp":
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
if errRet != nil {
err = errRet
return
}
kcpConn.SetStreamMode(true)
kcpConn.SetWriteDelay(true)
kcpConn.SetNoDelay(1, 20, 2, 1)
kcpConn.SetWindowSize(128, 512)
kcpConn.SetMtu(1350)
kcpConn.SetACKNoDelay(false)
kcpConn.SetReadBuffer(4194304)
kcpConn.SetWriteBuffer(4194304)
c = kcpConn
return
return DialKCPServer(addr)
case "websocket":
return DialWebsocketServer(addr)
default:
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
}
}

func DialKCPServer(addr string) (c net.Conn, err error) {
kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
if errRet != nil {
err = errRet
return
}
kcpConn.SetStreamMode(true)
kcpConn.SetWriteDelay(true)
kcpConn.SetNoDelay(1, 20, 2, 1)
kcpConn.SetWindowSize(128, 512)
kcpConn.SetMtu(1350)
kcpConn.SetACKNoDelay(false)
kcpConn.SetReadBuffer(4194304)
kcpConn.SetWriteBuffer(4194304)
c = kcpConn
return
}

func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) {
switch protocol {
case "tcp":
return gnet.DialTcpByProxy(proxyURL, addr)
case "kcp":
// http proxy is not supported for kcp
return ConnectServer(protocol, addr)
case "websocket":
return ConnectWebsocketServer(addr)
default:
return nil, fmt.Errorf("unsupport protocol: %s", protocol)
return nil, fmt.Errorf("unsupport protocol: %s when connecting by proxy", protocol)
}
}

func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (c net.Conn, err error) {
c, err = ConnectServerByProxy(proxyURL, protocol, addr)
// addr: domain:port
func DialWebsocketServer(addr string) (net.Conn, error) {
addr = "ws://" + addr + FrpWebsocketPath
uri, err := url.Parse(addr)
if err != nil {
return
return nil, err
}

if tlsConfig == nil {
return
origin := "http://" + uri.Host
cfg, err := websocket.NewConfig(addr, origin)
if err != nil {
return nil, err
}
cfg.Dialer = &net.Dialer{
Timeout: 10 * time.Second,
}

c = WrapTLSClientConn(c, tlsConfig, disableCustomTLSHeadByte)
return
conn, err := websocket.DialConfig(cfg)
if err != nil {
return nil, err
}
return conn, nil
}
89 changes: 89 additions & 0 deletions pkg/util/net/dial.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package net

import (
"crypto/tls"
"net"
)

type dialOptions struct {
proxyURL string
protocol string
tlsConfig *tls.Config
disableCustomTLSHeadByte bool
}

type DialOption interface {
apply(*dialOptions)
}

type EmptyDialOption struct{}

func (EmptyDialOption) apply(*dialOptions) {}

type funcDialOption struct {
f func(*dialOptions)
}

func (fdo *funcDialOption) apply(do *dialOptions) {
fdo.f(do)
}

func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
return &funcDialOption{
f: f,
}
}

func DefaultDialOptions() dialOptions {
return dialOptions{
protocol: "tcp",
}
}

func WithProxyURL(proxyURL string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.proxyURL = proxyURL
})
}

func WithTLSConfig(tlsConfig *tls.Config) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.tlsConfig = tlsConfig
})
}

func WithDisableCustomTLSHeadByte(disableCustomTLSHeadByte bool) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.disableCustomTLSHeadByte = disableCustomTLSHeadByte
})
}

func WithProtocol(protocol string) DialOption {
return newFuncDialOption(func(do *dialOptions) {
do.protocol = protocol
})
}

func DialWithOptions(addr string, opts ...DialOption) (c net.Conn, err error) {
op := DefaultDialOptions()

for _, opt := range opts {
opt.apply(&op)
}

if op.proxyURL == "" {
c, err = ConnectServer(op.protocol, addr)
} else {
c, err = ConnectServerByProxy(op.proxyURL, op.protocol, addr)
}
if err != nil {
return nil, err
}

if op.tlsConfig == nil {
return
}

c = WrapTLSClientConn(c, op.tlsConfig, op.disableCustomTLSHeadByte)
return
}
26 changes: 0 additions & 26 deletions pkg/util/net/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"time"

"golang.org/x/net/websocket"
)
Expand Down Expand Up @@ -77,27 +75,3 @@ func (p *WebsocketListener) Close() error {
func (p *WebsocketListener) Addr() net.Addr {
return p.ln.Addr()
}

// addr: domain:port
func ConnectWebsocketServer(addr string) (net.Conn, error) {
addr = "ws://" + addr + FrpWebsocketPath
uri, err := url.Parse(addr)
if err != nil {
return nil, err
}

origin := "http://" + uri.Host
cfg, err := websocket.NewConfig(addr, origin)
if err != nil {
return nil, err
}
cfg.Dialer = &net.Dialer{
Timeout: 10 * time.Second,
}

conn, err := websocket.DialConfig(cfg)
if err != nil {
return nil, err
}
return conn, nil
}

0 comments on commit ea568e8

Please sign in to comment.