Skip to content

Commit

Permalink
Cleaned up, added checking that server is known
Browse files Browse the repository at this point in the history
  • Loading branch information
alexlyulkov committed Nov 5, 2015
1 parent 01cef17 commit 0748070
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 130 deletions.
33 changes: 19 additions & 14 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (s *server) FindSimilarSite(fqdn string) (RemoteSite, error) {
if result != -1 {
return s.sites[result], nil
} else {
return nil, trace.Errorf("Site not found")
return nil, trace.Errorf("site not found")
}
}

Expand Down Expand Up @@ -339,18 +339,15 @@ func (s *remoteSite) GetLastConnected() time.Time {
func (s *remoteSite) ConnectToServer(server, user string, auth []ssh.AuthMethod) (*ssh.Client, error) {
ch, _, err := s.conn.OpenChannel(chanTransport, nil)
if err != nil {
log.Errorf("remoteSite:connectToServer %v", err)
return nil, err
return nil, trace.Wrap(err)
}
// ask remote channel to dial
dialed, err := ch.SendRequest(chanTransportDialReq, true, []byte(server))
if err != nil {
log.Errorf("failed to process request: %v", err)
return nil, err
return nil, trace.Wrap(err)
}
if !dialed {
log.Errorf("remote end failed to dial: %v", err)
return nil, fmt.Errorf("remote server %v is not available", server)
return nil, trace.Errorf("remote server %v is not available", server)
}
transportConn := newChConn(s.conn, ch)
conn, chans, reqs, err := ssh.NewClientConn(
Expand All @@ -367,21 +364,29 @@ func (s *remoteSite) ConnectToServer(server, user string, auth []ssh.AuthMethod)
}

func (s *remoteSite) DialServer(server string) (net.Conn, error) {
// TODO: check if server is known
serverIsKnown := false
knownServers, err := s.GetServers()
fmt.Println(server, "Known Servers:", knownServers)
for _, srv := range knownServers {
if srv.Addr == server {
serverIsKnown = true
}
}
serverIsKnown = serverIsKnown
if !serverIsKnown {
return nil, trace.Errorf("can't dial server %v, server is unknown", server)
}
ch, _, err := s.conn.OpenChannel(chanTransport, nil)
if err != nil {
log.Errorf("remoteSite:connectToServer %v", err)
return nil, err
return nil, trace.Wrap(err)
}
// ask remote channel to dial
dialed, err := ch.SendRequest(chanTransportDialReq, true, []byte(server))
if err != nil {
log.Errorf("failed to process request: %v", err)
return nil, err
return nil, trace.Wrap(err)
}
if !dialed {
log.Errorf("remote end failed to dial: %v", err)
return nil, fmt.Errorf("remote server %v is not available", server)
return nil, trace.Errorf("remote server %v is not available", server)
}
return newChConn(s.conn, ch), nil
}
Expand Down
31 changes: 28 additions & 3 deletions lib/srv/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"fmt"
"io"
"strings"
"sync"

"github.com/gravitational/teleport/Godeps/_workspace/src/github.com/gravitational/log"
"github.com/gravitational/teleport/Godeps/_workspace/src/github.com/gravitational/trace"
"github.com/gravitational/teleport/Godeps/_workspace/src/golang.org/x/crypto/ssh"
)
Expand All @@ -21,7 +23,7 @@ type proxySubsys struct {
func parseProxySubsys(name string, srv *Server) (*proxySubsys, error) {
out := strings.Split(name, ":")
if len(out) != 3 {
return nil, fmt.Errorf("invalid format for proxy request: '%v', expected 'proxy:host:port'", name)
return nil, trace.Errorf("invalid format for proxy request: '%v', expected 'proxy:host:port'", name)
}
return &proxySubsys{
srv: srv,
Expand All @@ -41,8 +43,31 @@ func (t *proxySubsys) execute(sconn *ssh.ServerConn, ch ssh.Channel, req *ssh.Re
}

conn, err := remoteSrv.DialServer(t.host + ":" + t.port)
if err != nil {
return trace.Wrap(err)
}

wg := &sync.WaitGroup{}
wg.Add(2)

go func() {
defer wg.Done()
_, err := io.Copy(ch, conn)
if err != nil {
log.Errorf(err.Error())
}
ch.Close()
}()
go func() {
defer wg.Done()
_, err := io.Copy(conn, ch)
if err != nil {
log.Errorf(err.Error())
}
conn.Close()
}()

wg.Wait()

go io.Copy(ch, conn)
io.Copy(conn, ch)
return nil
}
10 changes: 5 additions & 5 deletions lib/srv/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,22 +266,22 @@ func (s *Server) HandleRequest(r *ssh.Request) {
}

func (s *Server) HandleNewChan(sconn *ssh.ServerConn, nch ssh.NewChannel) {
cht := nch.ChannelType()
channelType := nch.ChannelType()

if s.proxyMode {
if cht == "session" { // interactive sessions
if channelType == "session" { // interactive sessions
ch, requests, err := nch.Accept()
if err != nil {
log.Infof("could not accept channel (%s)", err)
}
go s.handleSessionRequests(sconn, ch, requests)
} else {
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", cht))
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
}
return
}

switch cht {
switch channelType {
case "session": // interactive sessions
ch, requests, err := nch.Accept()
if err != nil {
Expand All @@ -300,7 +300,7 @@ func (s *Server) HandleNewChan(sconn *ssh.ServerConn, nch ssh.NewChannel) {
}
go s.handleDirectTCPIPRequest(sconn, sshCh, req)
default:
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", cht))
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
}
}

Expand Down
72 changes: 56 additions & 16 deletions lib/srv/srv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ import (
func TestSrv(t *testing.T) { TestingT(t) }

type SrvSuite struct {
srv *Server
clt *ssh.Client
bk *encryptedbk.ReplicatedBackend
a *auth.AuthServer
up *upack
scrt secret.SecretService
signer ssh.Signer
dir string
srv *Server
srvAddress string
clt *ssh.Client
bk *encryptedbk.ReplicatedBackend
a *auth.AuthServer
up *upack
scrt secret.SecretService
signer ssh.Signer
dir string
}

var _ = Suite(&SrvSuite{})
Expand Down Expand Up @@ -80,9 +81,9 @@ func (s *SrvSuite) SetUpTest(c *C) {
c.Assert(err, IsNil)

ap := auth.NewBackendAccessPoint(s.bk)

s.srvAddress = "localhost:30185"
srv, err := New(
utils.NetAddr{Network: "tcp", Addr: "localhost:30185"},
utils.NetAddr{Network: "tcp", Addr: s.srvAddress},
[]ssh.Signer{s.signer},
ap,
SetShell("/bin/sh"),
Expand Down Expand Up @@ -274,22 +275,25 @@ func (s *SrvSuite) TestProxy(c *C) {
Auth: []ssh.AuthMethod{ssh.PublicKeys(up.certSigner)},
}

// Trying to connect to unregistered ssh node

client, err := ssh.Dial("tcp", proxy.Addr(), sshConfig)
c.Assert(err, IsNil)
c.Assert(agent.ForwardToAgent(client, keyring), IsNil)

se, err := client.NewSession()
se0, err := client.NewSession()
c.Assert(err, IsNil)
defer se.Close()
defer se0.Close()

writer, err := se.StdinPipe()
writer, err := se0.StdinPipe()
c.Assert(err, IsNil)

reader, err := se.StdoutPipe()
reader, err := se0.StdoutPipe()
c.Assert(err, IsNil)

// Request opening TCP connection to the remote host
c.Assert(se.RequestSubsystem(fmt.Sprintf("proxy:%v", s.srv.Addr())), IsNil)
unregisteredAddress := s.srv.Addr() // proper ssh node address but with 127.0.0.1 instead of localhost
c.Assert(se0.RequestSubsystem(fmt.Sprintf("proxy:%v", unregisteredAddress)), IsNil)

local, err := net.ResolveTCPAddr("tcp", proxy.Addr())
c.Assert(err, IsNil)
Expand All @@ -299,14 +303,50 @@ func (s *SrvSuite) TestProxy(c *C) {
pipeNetConn := utils.NewPipeNetConn(
reader,
writer,
se,
se0,
local,
remote,
)

// Open SSH connection via TCP
conn, chans, reqs, err := ssh.NewClientConn(pipeNetConn,
s.srv.Addr(), sshConfig)
c.Assert(err, NotNil)

// Connect to node using registered address
client, err = ssh.Dial("tcp", proxy.Addr(), sshConfig)
c.Assert(err, IsNil)
c.Assert(agent.ForwardToAgent(client, keyring), IsNil)

se, err := client.NewSession()
c.Assert(err, IsNil)
defer se.Close()

writer, err = se.StdinPipe()
c.Assert(err, IsNil)

reader, err = se.StdoutPipe()
c.Assert(err, IsNil)

// Request opening TCP connection to the remote host
c.Assert(se.RequestSubsystem(fmt.Sprintf("proxy:%v", s.srvAddress)), IsNil)

local, err = net.ResolveTCPAddr("tcp", proxy.Addr())
c.Assert(err, IsNil)
remote, err = net.ResolveTCPAddr("tcp", s.srv.Addr())
c.Assert(err, IsNil)

pipeNetConn = utils.NewPipeNetConn(
reader,
writer,
se,
local,
remote,
)

// Open SSH connection via TCP
conn, chans, reqs, err = ssh.NewClientConn(pipeNetConn,
s.srv.Addr(), sshConfig)
c.Assert(err, IsNil)

// using this connection as regular SSH
Expand Down
72 changes: 29 additions & 43 deletions lib/teleagent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package teleagent
import (
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"io"
"net"
"time"

Expand All @@ -15,7 +16,14 @@ import (
)

type TeleAgent struct {
keys []Key
agent agent.Agent
}

func NewTeleAgent() *TeleAgent {
ta := TeleAgent{
agent: agent.NewKeyring(),
}
return &ta
}

func (a *TeleAgent) Start(agentAddr string) error {
Expand All @@ -32,48 +40,23 @@ func (a *TeleAgent) Start(agentAddr string) error {
go func() {
for {
conn, err := l.Accept()
ag, err := a.GetAgent()
if err != nil {
log.Errorf(err.Error())
} else {
go func() {
if err := agent.ServeAgent(ag, conn); err != nil {
continue
}
go func() {
if err := agent.ServeAgent(a.agent, conn); err != nil {
if err != io.EOF {
log.Errorf(err.Error())
}
}()
}
}
}()
}
}()

return nil
}

func (a *TeleAgent) GetAgent() (agent.Agent, error) {
ag := agent.NewKeyring()

for _, key := range a.keys {
k, err := ssh.ParseRawPrivateKey(key.Priv)
if err != nil {
log.Errorf("failed to add: %v", err)
return nil, trace.Wrap(err)
}
addedKey := agent.AddedKey{
PrivateKey: k,
Certificate: key.Cert,
Comment: "",
LifetimeSecs: 0,
ConfirmBeforeUse: false,
}
if err := ag.Add(addedKey); err != nil {
log.Errorf("failed to add: %v", err)
return nil, trace.Wrap(err)
}
}

return ag, nil

}

func (a *TeleAgent) Login(proxyAddr string, user string, pass string,
hotpToken string, ttl time.Duration) error {
priv, pub, err := native.New().GenerateKeyPair("")
Expand All @@ -92,21 +75,24 @@ func (a *TeleAgent) Login(proxyAddr string, user string, pass string,
return trace.Wrap(err)
}

key := Key{
Priv: priv,
Cert: pcert.(*ssh.Certificate),
pk, err := ssh.ParseRawPrivateKey(priv)
if err != nil {
return trace.Wrap(err)
}
addedKey := agent.AddedKey{
PrivateKey: pk,
Certificate: pcert.(*ssh.Certificate),
Comment: "",
LifetimeSecs: 0,
ConfirmBeforeUse: false,
}
if err := a.agent.Add(addedKey); err != nil {
return trace.Wrap(err)
}

a.keys = append(a.keys, key)

return nil
}

type Key struct {
Priv []byte
Cert *ssh.Certificate
}

const (
DefaultAgentAddress = "unix:///tmp/teleport.agent.sock"
)
Loading

0 comments on commit 0748070

Please sign in to comment.