Skip to content

Commit

Permalink
add secure connections for all communications
Browse files Browse the repository at this point in the history
Need testing. Only verified existing logic are not broken.
  • Loading branch information
chrislusf committed Feb 8, 2016
1 parent fbff8a2 commit 51927e1
Show file tree
Hide file tree
Showing 17 changed files with 205 additions and 118 deletions.
15 changes: 12 additions & 3 deletions agent/agent_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package agent

import (
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"log"
Expand All @@ -16,6 +17,7 @@ import (
"sync"

"github.com/chrislusf/glow/driver/cmd"
"github.com/chrislusf/glow/netchan"
"github.com/chrislusf/glow/resource"
"github.com/chrislusf/glow/resource/service_discovery/client"
"github.com/chrislusf/glow/util"
Expand All @@ -32,6 +34,7 @@ type AgentServerOption struct {
MemoryMB *int64
CPULevel *int
CleanRestart *bool
CertFiles netchan.CertFiles
}

type AgentServer struct {
Expand Down Expand Up @@ -69,7 +72,7 @@ func NewAgentServer(option *AgentServerOption) *AgentServer {
localExecutorManager: newLocalExecutorsManager(),
}

err = as.Init()
err = as.init()
if err != nil {
panic(err)
}
Expand All @@ -80,11 +83,17 @@ func NewAgentServer(option *AgentServerOption) *AgentServer {
// Start starts to listen on a port, returning the listening port
// r.Port can be pre-set or leave it as zero
// The actual port set to r.Port
func (r *AgentServer) Init() (err error) {
r.listener, err = net.Listen("tcp", ":"+strconv.Itoa(r.Port))
func (r *AgentServer) init() (err error) {
tlsConfig := r.Option.CertFiles.MakeTLSConfig()
if tlsConfig == nil {
r.listener, err = net.Listen("tcp", ":"+strconv.Itoa(r.Port))
} else {
r.listener, err = tls.Listen("tcp", ":"+strconv.Itoa(r.Port), tlsConfig)
}
if err != nil {
log.Fatal(err)
}
util.SetupHttpClient(tlsConfig)

r.Port = r.listener.Addr().(*net.TCPAddr).Port
fmt.Println("AgentServer starts on:", r.Port)
Expand Down
10 changes: 10 additions & 0 deletions driver/context_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"github.com/chrislusf/glow/driver/rsync"
"github.com/chrislusf/glow/driver/scheduler"
"github.com/chrislusf/glow/flow"
"github.com/chrislusf/glow/netchan"
"github.com/chrislusf/glow/util"
)

type DriverOption struct {
Expand All @@ -27,6 +29,7 @@ type DriverOption struct {
RelatedFiles string
ShowFlowStats bool
ListenOn string
CertFiles netchan.CertFiles
}

func init() {
Expand All @@ -42,6 +45,9 @@ func init() {
flag.StringVar(&driverOption.RelatedFiles, "glow.related.files", "", strconv.QuoteRune(os.PathListSeparator)+" separated list of files")
flag.BoolVar(&driverOption.ShowFlowStats, "glow.flow.stat", false, "show flow details at the end of execution")
flag.StringVar(&driverOption.ListenOn, "glow.driver.listenOn", ":0", "listen on this address to copy itself and related files to agents")
flag.StringVar(&driverOption.CertFiles.CertFile, "cert.file", "", "A PEM eoncoded certificate file")
flag.StringVar(&driverOption.CertFiles.KeyFile, "key.file", "", "A PEM encoded private key file")
flag.StringVar(&driverOption.CertFiles.CaFile, "ca.file", "", "A PEM eoncoded CA's certificate file")

flow.RegisterContextRunner(NewFlowContextDriver(&driverOption))
}
Expand Down Expand Up @@ -81,6 +87,9 @@ func (fcd *FlowContextDriver) Run(fc *flow.FlowContext) {
return
}

tlsConfig := fcd.option.CertFiles.MakeTLSConfig()
util.SetupHttpClient(tlsConfig)

// start server to serve files to agents to run exectuors
rsyncServer, err := rsync.NewRsyncServer(fcd.option.ListenOn, os.Args[0], fcd.option.RelatedFileNames())
if err != nil {
Expand All @@ -99,6 +108,7 @@ func (fcd *FlowContextDriver) Run(fc *flow.FlowContext) {
Module: fcd.option.Module,
ExecutableFile: os.Args[0],
ExecutableFileHash: rsyncServer.ExecutableFileHash(),
TlsConfig: tlsConfig,
},
)

Expand Down
4 changes: 2 additions & 2 deletions driver/rsync/fetch_url.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type ListFileResult struct {
}

func ListFiles(server string) ([]FileHash, error) {
jsonBlob, err := util.Get("http://" + server + "/list")
jsonBlob, err := util.Get(util.SchemePrefix + server + "/list")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -48,7 +48,7 @@ func FetchFilesTo(driverAddress string, dir string) error {
// println("skip downloading same", fh.File)
continue
}
if err = FetchUrl("http://"+driverAddress+"/file/"+fh.File, toFile); err != nil {
if err = FetchUrl(util.SchemePrefix+driverAddress+"/file/"+fh.File, toFile); err != nil {
return fmt.Errorf("Failed to download file %s: %v", fh.File, err)
}
}
Expand Down
2 changes: 2 additions & 0 deletions driver/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package scheduler

import (
"crypto/tls"
"sync"
"time"

Expand Down Expand Up @@ -40,6 +41,7 @@ type SchedulerOption struct {
Module string
ExecutableFile string
ExecutableFileHash string
TlsConfig *tls.Config
}

func NewScheduler(leader string, option *SchedulerOption) *Scheduler {
Expand Down
4 changes: 2 additions & 2 deletions driver/scheduler/scheduler_execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (s *Scheduler) setupInputChannels(fc *flow.FlowContext, task *flow.Task, lo
inputChanName := fmt.Sprintf("%s-ct-%d-input-%d-p-%d", s.option.ExecutableFileHash, fc.Id, ds.Id, i)
// println("setup input channel for", task.Name(), "on", location.URL())
s.shardLocator.SetShardLocation(inputChanName, location)
rawChan, err := netchan.GetDirectSendChannel(inputChanName, location.URL(), waitGroup)
rawChan, err := netchan.GetDirectSendChannel(s.option.TlsConfig, inputChanName, location.URL(), waitGroup)
if err != nil {
log.Panic(err)
}
Expand All @@ -106,7 +106,7 @@ func (s *Scheduler) setupOutputChannels(shards []*flow.DatasetShard, waitGroup *
// connect remote raw chan to local typed chan
readChanName := s.option.ExecutableFileHash + "-" + shard.Name()
location, _ := s.shardLocator.GetShardLocation(readChanName)
rawChan, err := netchan.GetDirectReadChannel(readChanName, location.URL(), 1024)
rawChan, err := netchan.GetDirectReadChannel(s.option.TlsConfig, readChanName, location.URL(), 1024)
if err != nil {
log.Panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion driver/scheduler/scheduler_fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func Assign(leader string, request *resource.AllocationRequest) (*resource.Alloc
values := make(url.Values)
requestBlob, _ := json.Marshal(request)
values.Add("request", string(requestBlob))
jsonBlob, err := util.Post("http://"+leader+"/agent/assign", values)
jsonBlob, err := util.Post(util.SchemePrefix+leader+"/agent/assign", values)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions driver/task_option.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package driver

import (
"crypto/tls"
"flag"

"github.com/chrislusf/glow/flow"
Expand All @@ -14,6 +15,7 @@ type TaskOption struct {
ExecutableFileHash string
ChannelBufferSize int
RequestId uint64
TlsConfig *tls.Config
}

var taskOption TaskOption
Expand Down
6 changes: 3 additions & 3 deletions driver/task_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (tr *TaskRunner) connectExternalInputs(wg *sync.WaitGroup, name2Location ma
d := shard.Parent
readChanName := tr.option.ExecutableFileHash + "-" + shard.Name()
// println("taskGroup", tr.option.TaskGroupId, "firstTask", firstTask.Name(), "trying to read from:", readChanName, len(firstTask.InputChans))
rawChan, err := netchan.GetDirectReadChannel(readChanName, name2Location[readChanName], tr.FlowContext.ChannelBufferSize)
rawChan, err := netchan.GetDirectReadChannel(tr.option.TlsConfig, readChanName, name2Location[readChanName], tr.FlowContext.ChannelBufferSize)
if err != nil {
log.Panic(err)
}
Expand All @@ -135,7 +135,7 @@ func (tr *TaskRunner) connectExternalInputChannels(wg *sync.WaitGroup) {
ds := firstTask.Outputs[0].Parent
for i, _ := range ds.ExternalInputChans {
inputChanName := fmt.Sprintf("%s-ct-%d-input-%d-p-%d", tr.option.ExecutableFileHash, tr.option.ContextId, ds.Id, i)
rawChan, err := netchan.GetLocalReadChannel(inputChanName, tr.FlowContext.ChannelBufferSize)
rawChan, err := netchan.GetLocalReadChannel(tr.option.TlsConfig, inputChanName, tr.FlowContext.ChannelBufferSize)
if err != nil {
log.Panic(err)
}
Expand All @@ -151,7 +151,7 @@ func (tr *TaskRunner) connectExternalOutputs(wg *sync.WaitGroup) {
for _, shard := range lastTask.Outputs {
writeChanName := tr.option.ExecutableFileHash + "-" + shard.Name()
// println("taskGroup", tr.option.TaskGroupId, "step", lastTask.Step.Id, "lastTask", lastTask.Id, "writing to:", writeChanName)
rawChan, err := netchan.GetLocalSendChannel(writeChanName, wg)
rawChan, err := netchan.GetLocalSendChannel(tr.option.TlsConfig, writeChanName, wg)
if err != nil {
log.Panic(err)
}
Expand Down
41 changes: 35 additions & 6 deletions glow.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,28 @@ import (
"strconv"
"sync"

"github.com/chrislusf/glow/netchan"

kingpin "gopkg.in/alecthomas/kingpin.v2"

a "github.com/chrislusf/glow/agent"
r "github.com/chrislusf/glow/netchan/receiver"
s "github.com/chrislusf/glow/netchan/sender"
m "github.com/chrislusf/glow/resource/service_discovery/master"
"github.com/chrislusf/glow/util"
)

var (
app = kingpin.New("glow", "A command-line net channel.")

master = app.Command("master", "Start a master process")
masterPort = master.Flag("port", "listening port").Default("8930").Int()
masterIp = master.Flag("ip", "listening IP adress").Default("localhost").String()
master = app.Command("master", "Start a master process")
masterPort = master.Flag("port", "listening port").Default("8930").Int()
masterIp = master.Flag("ip", "listening IP adress").Default("localhost").String()
masterCerts = netchan.CertFiles{
CertFile: *master.Flag("cert.file", "A PEM eoncoded certificate file").Default("").String(),
KeyFile: *master.Flag("key.file", "A PEM encoded private key file.").Default("").String(),
CaFile: *master.Flag("ca.file", "A PEM eoncoded CA's certificate file.").Default("").String(),
}

agent = app.Command("agent", "Channel Agent")
agentOption = &a.AgentServerOption{
Expand All @@ -34,27 +42,45 @@ var (
CPULevel: agent.Flag("cpu.level", "relative computing power of single cpu core").Default("1").Int(),
MemoryMB: agent.Flag("memory", "memory size in MB").Default("1024").Int64(),
CleanRestart: agent.Flag("clean.restart", "clean up previous dataset files").Default("true").Bool(),
CertFiles: netchan.CertFiles{
CertFile: *master.Flag("cert.file", "A PEM eoncoded certificate file").Default("").String(),
KeyFile: *master.Flag("key.file", "A PEM encoded private key file.").Default("").String(),
CaFile: *master.Flag("ca.file", "A PEM eoncoded CA's certificate file.").Default("").String(),
},
}

sender = app.Command("send", "Send data to a channel")
sendToChanName = sender.Flag("to", "Name of a channel").Required().String()
sendFile = sender.Flag("file", "file to post.").ExistingFile()
senderAgentPort = sender.Flag("port", "agent listening port").Default("8931").Int()
senderCerts = netchan.CertFiles{
CertFile: *sender.Flag("cert.file", "A PEM eoncoded certificate file").Default("").String(),
KeyFile: *sender.Flag("key.file", "A PEM encoded private key file.").Default("").String(),
CaFile: *sender.Flag("ca.file", "A PEM eoncoded CA's certificate file.").Default("").String(),
}
// sendDelimiter = sender.Flag("delimiter", "Verbose mode.").Short('d').String()

receiver = app.Command("receive", "Receive data from a channel")
receiveFromChanName = receiver.Flag("from", "Name of a source channel").Required().String()
receiverMaster = receiver.Flag("master", "ip:port format").Default("localhost:8930").String()
receiverCerts = netchan.CertFiles{
CertFile: *receiver.Flag("cert.file", "A PEM eoncoded certificate file").Default("").String(),
KeyFile: *receiver.Flag("key.file", "A PEM encoded private key file.").Default("").String(),
CaFile: *receiver.Flag("ca.file", "A PEM eoncoded CA's certificate file.").Default("").String(),
}
)

func main() {
switch kingpin.MustParse(app.Parse(os.Args[1:])) {
case master.FullCommand():
println("listening on", (*masterIp)+":"+strconv.Itoa(*masterPort))
m.RunMaster((*masterIp) + ":" + strconv.Itoa(*masterPort))
m.RunMaster(masterCerts.MakeTLSConfig(), (*masterIp)+":"+strconv.Itoa(*masterPort))
case sender.FullCommand():
tlsConfig := senderCerts.MakeTLSConfig()
util.SetupHttpClient(tlsConfig)

var wg sync.WaitGroup
sendChan, err := s.NewSendChannel(*sendToChanName, *senderAgentPort, &wg)
sendChan, err := s.NewSendChannel(tlsConfig, *sendToChanName, *senderAgentPort, &wg)
if err != nil {
panic(err)
}
Expand All @@ -81,8 +107,11 @@ func main() {
wg.Wait()

case receiver.FullCommand():
tlsConfig := receiverCerts.MakeTLSConfig()
util.SetupHttpClient(tlsConfig)

target := r.FindTarget(*receiveFromChanName, *receiverMaster)
rc := r.NewReceiveChannel(*receiveFromChanName, 0)
rc := r.NewReceiveChannel(tlsConfig, *receiveFromChanName, 0)
recvChan, err := rc.GetDirectChannel(target, 128)
if err != nil {
panic(err)
Expand Down
17 changes: 9 additions & 8 deletions netchan/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package netchan

import (
"bytes"
"crypto/tls"
"encoding/gob"
"flag"
"log"
Expand All @@ -26,21 +27,21 @@ func init() {
flag.IntVar(&Option.AgentPort, "glow.agent.port", 8931, "agent port")
}

func GetLocalSendChannel(name string, wg *sync.WaitGroup) (chan []byte, error) {
return sender.NewSendChannel(name, Option.AgentPort, wg)
func GetLocalSendChannel(tlsConfig *tls.Config, name string, wg *sync.WaitGroup) (chan []byte, error) {
return sender.NewSendChannel(tlsConfig, name, Option.AgentPort, wg)
}

func GetLocalReadChannel(name string, chanBufferSize int) (chan []byte, error) {
return GetDirectReadChannel(name, "localhost:"+strconv.Itoa(Option.AgentPort), chanBufferSize)
func GetLocalReadChannel(tlsConfig *tls.Config, name string, chanBufferSize int) (chan []byte, error) {
return GetDirectReadChannel(tlsConfig, name, "localhost:"+strconv.Itoa(Option.AgentPort), chanBufferSize)
}

func GetDirectReadChannel(name, location string, chanBufferSize int) (chan []byte, error) {
rc := receiver.NewReceiveChannel(name, 0)
func GetDirectReadChannel(tlsConfig *tls.Config, name, location string, chanBufferSize int) (chan []byte, error) {
rc := receiver.NewReceiveChannel(tlsConfig, name, 0)
return rc.GetDirectChannel(location, chanBufferSize)
}

func GetDirectSendChannel(name string, target string, wg *sync.WaitGroup) (chan []byte, error) {
return sender.NewDirectSendChannel(name, target, wg)
func GetDirectSendChannel(tlsConfig *tls.Config, name string, target string, wg *sync.WaitGroup) (chan []byte, error) {
return sender.NewDirectSendChannel(tlsConfig, name, target, wg)
}

func ConnectRawReadChannelToTyped(c chan []byte, out chan reflect.Value, t reflect.Type, wg *sync.WaitGroup) (status *util.ChannelStatus) {
Expand Down
Loading

0 comments on commit 51927e1

Please sign in to comment.