Skip to content

Commit

Permalink
decouple api module, update proxy_test.go
Browse files Browse the repository at this point in the history
  • Loading branch information
p4gefau1t committed May 2, 2020
1 parent 85b8817 commit 6ac495e
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 20 deletions.
5 changes: 5 additions & 0 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/conf"
"github.com/p4gefau1t/trojan-go/log"
"github.com/p4gefau1t/trojan-go/proxy"
"github.com/p4gefau1t/trojan-go/stat"
"google.golang.org/grpc"
)
Expand Down Expand Up @@ -74,3 +75,7 @@ func RunClientAPI(ctx context.Context, config *conf.GlobalConfig, auth stat.Auth
return nil
}
}

func init() {
proxy.RegisterAPI(conf.Client, RunClientAPI)
}
5 changes: 5 additions & 0 deletions api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/conf"
"github.com/p4gefau1t/trojan-go/log"
"github.com/p4gefau1t/trojan-go/proxy"
"github.com/p4gefau1t/trojan-go/stat"
grpc "google.golang.org/grpc"
)
Expand Down Expand Up @@ -173,3 +174,7 @@ func RunServerAPI(ctx context.Context, config *conf.GlobalConfig, auth stat.Auth
return nil
}
}

func init() {
proxy.RegisterAPI(conf.Server, RunServerAPI)
}
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

//the following modules are optional
//you can comment some of them if you don't need them
_ "github.com/p4gefau1t/trojan-go/api"
_ "github.com/p4gefau1t/trojan-go/cert"
_ "github.com/p4gefau1t/trojan-go/daemon"
_ "github.com/p4gefau1t/trojan-go/easy"
Expand Down
3 changes: 1 addition & 2 deletions proxy/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"io"
"net"

"github.com/p4gefau1t/trojan-go/api"
"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/conf"
"github.com/p4gefau1t/trojan-go/log"
Expand Down Expand Up @@ -322,7 +321,7 @@ func (c *Client) Run() error {
go c.listenTCP(errChan)
if c.config.API.Enabled {
go func() {
errChan <- api.RunClientAPI(c.ctx, c.config, c.auth)
errChan <- proxy.RunAPIService(conf.Client, c.ctx, c.config, c.auth)
}()
}
select {
Expand Down
23 changes: 20 additions & 3 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/p4gefau1t/trojan-go/log"
"github.com/p4gefau1t/trojan-go/protocol"
"github.com/p4gefau1t/trojan-go/router"
"github.com/p4gefau1t/trojan-go/stat"
)

type Buildable interface {
Expand Down Expand Up @@ -113,16 +114,32 @@ func ProxyPacketWithRouter(ctx context.Context, from protocol.PacketReadWriter,
}
}

var buildableMap = make(map[conf.RunType]Buildable)
var proxys = make(map[conf.RunType]Buildable)

func NewProxy(config *conf.GlobalConfig) (common.Runnable, error) {
runType := config.RunType
if buildable, found := buildableMap[runType]; found {
if buildable, found := proxys[runType]; found {
return buildable.Build(config)
}
return nil, common.NewError("invalid run_type " + string(runType))
}

func RegisterProxy(t conf.RunType, b Buildable) {
buildableMap[t] = b
proxys[t] = b
}

type APIRunner func(context.Context, *conf.GlobalConfig, stat.Authenticator) error

var apis = make(map[conf.RunType]APIRunner)

func RegisterAPI(t conf.RunType, r APIRunner) {
apis[t] = r
}

func RunAPIService(t conf.RunType, ctx context.Context, config *conf.GlobalConfig, auth stat.Authenticator) error {
r, ok := apis[t]
if !ok {
return common.NewError("api module for" + string(t) + "not found")
}
return r(ctx, config, auth)
}
3 changes: 1 addition & 2 deletions proxy/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"net"

"github.com/p4gefau1t/trojan-go/api"
"github.com/p4gefau1t/trojan-go/common"
"github.com/p4gefau1t/trojan-go/conf"
"github.com/p4gefau1t/trojan-go/log"
Expand Down Expand Up @@ -191,7 +190,7 @@ func (s *Server) Run() error {
if s.config.API.Enabled {
log.Info("api enabled")
go func() {
errChan <- api.RunServerAPI(s.ctx, s.config, s.auth)
errChan <- proxy.RunAPIService(conf.Server, s.ctx, s.config, s.auth)
}()
}
go s.ListenTCP(errChan)
Expand Down
35 changes: 22 additions & 13 deletions test/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io/ioutil"
"net"
"os"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -186,14 +187,20 @@ func addTCPOption(config *conf.GlobalConfig) *conf.GlobalConfig {
return config
}

func addMySQLConfig(config *conf.GlobalConfig) *conf.GlobalConfig {
func addMySQLConfig(t *testing.T, config *conf.GlobalConfig) *conf.GlobalConfig {
database := os.Getenv("mysql_database")
username := os.Getenv("mysql_username")
password := os.Getenv("mysql_password")
if database == "" || username == "" || password == "" {
t.Skip("skipping mysql test")
}
config.MySQL = conf.MySQLConfig{
Enabled: true,
ServerHost: "127.0.0.1",
ServerPort: 3306,
Database: "trojan",
Username: "root",
Password: "password",
Database: database,
Username: username,
Password: password,
CheckRate: 1,
}
return config
Expand Down Expand Up @@ -352,14 +359,15 @@ func MultiThreadSpeedTestClientServer(b *testing.B, clientConfig *conf.GlobalCon
cancel()
}

func TestIt(t *testing.T) {
/*
clientConfig := getBasicClientConfig()
serverConfig := getBasicServerConfig()
go RunClient(context.Background(), clientConfig)
go RunHelloHTTPServer(context.Background())
RunServer(context.Background(), serverConfig)
*/
func TestRealProxy(t *testing.T) {
if os.Getenv("real_test") == "" {
t.Skip("skipping real proxy test")
}
clientConfig := getBasicClientConfig()
serverConfig := getBasicServerConfig()
go RunClient(context.Background(), clientConfig)
go RunHelloHTTPServer(context.Background())
RunServer(context.Background(), serverConfig)
}

func TestNormal(t *testing.T) {
Expand Down Expand Up @@ -506,7 +514,7 @@ func TestTCPOptions(t *testing.T) {
}

func TestMySQL(t *testing.T) {
serverConfig := addMySQLConfig(getBasicServerConfig())
serverConfig := addMySQLConfig(t, getBasicServerConfig())
clientConfig := getBasicClientConfig()
clientConfig.Passwords = getPasswords("mysqlpassword")
clientConfig.Hash = getHash("mysqlpassword")
Expand All @@ -527,6 +535,7 @@ func TestServerAPI(t *testing.T) {

time.Sleep(time.Second * 2)
grpcConn, err := grpc.Dial("127.0.0.1:10000", grpc.WithInsecure())
common.Must(err)
server := api.NewTrojanServerServiceClient(grpcConn)

listUserStream, err := server.ListUsers(ctx, &api.ListUserRequest{})
Expand Down

0 comments on commit 6ac495e

Please sign in to comment.