Skip to content

Commit

Permalink
fix issue with out-of-order packets
Browse files Browse the repository at this point in the history
  • Loading branch information
eikenb committed Feb 19, 2017
1 parent 99921c4 commit 94ae9c6
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 65 deletions.
2 changes: 1 addition & 1 deletion examples/request-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"net"
"os"

"github.com/pkg/sftp"
"github.com/eikenb/sftp"
"golang.org/x/crypto/ssh"
)

Expand Down
36 changes: 23 additions & 13 deletions request-example.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"os"
"path/filepath"
"sync"
"syscall"
"time"
)
Expand All @@ -24,7 +25,7 @@ func InMemHandler() Handlers {
}

// Handlers
func (fs *root) Fileread(r *Request) (io.Reader, error) {
func (fs *root) Fileread(r *Request) (io.ReaderAt, error) {
file, err := fs.fetch(r.Filepath)
if err != nil {
return nil, err
Expand All @@ -35,10 +36,10 @@ func (fs *root) Fileread(r *Request) (io.Reader, error) {
return nil, err
}
}
return file.Reader()
return file.ReaderAt()
}

func (fs *root) Filewrite(r *Request) (io.Writer, error) {
func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
file, err := fs.fetch(r.Filepath)
if err == os.ErrNotExist {
dir, err := fs.fetch(filepath.Dir(r.Filepath))
Expand All @@ -51,7 +52,7 @@ func (fs *root) Filewrite(r *Request) (io.Writer, error) {
file = newMemFile(r.Filepath, false)
fs.files[r.Filepath] = file
}
return file.Writer()
return file.WriterAt()
}

func (fs *root) Filecmd(r *Request) error {
Expand Down Expand Up @@ -144,11 +145,12 @@ func (fs *root) fetch(path string) (*memFile, error) {
// Implements os.FileInfo, Reader and Writer interfaces.
// These are the 3 interfaces necessary for the Handlers.
type memFile struct {
name string
content []byte
modtime time.Time
symlink string
isdir bool
name string
modtime time.Time
symlink string
isdir bool
content []byte
contentLock sync.RWMutex
}

// factory to make sure modtime is set
Expand Down Expand Up @@ -180,22 +182,30 @@ func (f *memFile) Sys() interface{} {
}

// Read/Write
func (f *memFile) Reader() (io.Reader, error) {
func (f *memFile) ReaderAt() (io.ReaderAt, error) {
if f.isdir {
return nil, os.ErrInvalid
}
return bytes.NewReader(f.content), nil
}

func (f *memFile) Writer() (io.Writer, error) {
func (f *memFile) WriterAt() (io.WriterAt, error) {
if f.isdir {
return nil, os.ErrInvalid
}
return f, nil
}
func (f *memFile) Write(p []byte) (int, error) {
func (f *memFile) WriteAt(p []byte, off int64) (int, error) {
// mimic write delays, should be optional
time.Sleep(time.Microsecond * time.Duration(len(p)))
f.content = append(f.content, p...)
f.contentLock.Lock()
defer f.contentLock.Unlock()
plen := len(p) + int(off)
if plen >= len(f.content) {
nc := make([]byte, plen)
copy(nc, f.content)
f.content = nc
}
copy(f.content[off:], p)
return len(p), nil
}
4 changes: 2 additions & 2 deletions request-interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (

// FileReader should return an io.Reader for the filepath
type FileReader interface {
Fileread(*Request) (io.Reader, error)
Fileread(*Request) (io.ReaderAt, error)
}

// FileWriter should return an io.Writer for the filepath
type FileWriter interface {
Filewrite(*Request) (io.Writer, error)
Filewrite(*Request) (io.WriterAt, error)
}

// FileCmder should return an error (rename, remove, setstate, etc.)
Expand Down
92 changes: 63 additions & 29 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"path"
"path/filepath"
"sync"
"syscall"
)

Expand All @@ -17,21 +18,47 @@ type Request struct {
Attrs []byte // convert to sub-struct
Target string // for renames and sym-links
// packet data
pkt_id uint32
data []byte
length uint32
packets []packet_data
packetsLock sync.RWMutex
// reader/writer from handlers
put_writer io.Writer
get_reader io.Reader
put_writer io.WriterAt
get_reader io.ReaderAt
eof bool // hack for readdir to keep eof state
}

type packet_data struct {
id uint32
data []byte
length uint32
offset int64
}

// Here mainly to specify that Filepath is required
func newRequest(path string) *Request {
request := &Request{Filepath: filepath.Clean(path)}
return request
}

// push packet_data into fifo
func (r *Request) pushPacket(pd packet_data) {
r.packetsLock.Lock()
defer r.packetsLock.Unlock()
r.packets = append(r.packets, pd)
}

// pop packet_data into fifo
func (r *Request) popPacket() packet_data {
r.packetsLock.Lock()
defer r.packetsLock.Unlock()
var pd packet_data
pd, r.packets = r.packets[0], r.packets[1:]
return pd
}

func (r *Request) pkt_id() uint32 {
return r.packets[0].id
}

// called from worker to handle packet/request
func (r *Request) handle(handlers Handlers) (responsePacket, error) {
var err error
Expand Down Expand Up @@ -59,13 +86,15 @@ func fileget(h FileReader, r *Request) (responsePacket, error) {
r.get_reader = reader
}
reader := r.get_reader
data := make([]byte, clamp(r.length, maxTxPacket))
n, err := reader.Read(data)

pd := r.popPacket()
data := make([]byte, clamp(pd.length, maxTxPacket))
n, err := reader.ReadAt(data, pd.offset)
if err != nil && (err != io.EOF || n == 0) {
return nil, err
}
return &sshFxpDataPacket{
ID: r.pkt_id,
ID: pd.id,
Length: uint32(n),
Data: data[:n],
}, nil
Expand All @@ -82,12 +111,13 @@ func fileput(h FileWriter, r *Request) (responsePacket, error) {
}
writer := r.put_writer

_, err := writer.Write(r.data)
pd := r.popPacket()
_, err := writer.WriteAt(pd.data, pd.offset)
if err != nil {
return nil, err
}
return &sshFxpStatusPacket{
ID: r.pkt_id,
ID: pd.id,
StatusError: StatusError{
Code: ssh_FX_OK,
}}, nil
Expand All @@ -100,7 +130,7 @@ func filecmd(h FileCmder, r *Request) (responsePacket, error) {
return nil, err
}
return &sshFxpStatusPacket{
ID: r.pkt_id,
ID: r.pkt_id(),
StatusError: StatusError{
Code: ssh_FX_OK,
}}, nil
Expand All @@ -119,7 +149,7 @@ func fileinfo(h FileInfoer, r *Request) (responsePacket, error) {
switch r.Method {
case "List":
dirname := path.Base(r.Filepath)
ret := &sshFxpNamePacket{ID: r.pkt_id}
ret := &sshFxpNamePacket{ID: r.pkt_id()}
for _, fi := range finfo {
ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{
Name: fi.Name(),
Expand All @@ -136,7 +166,7 @@ func fileinfo(h FileInfoer, r *Request) (responsePacket, error) {
return nil, err
}
return &sshFxpStatResponse{
ID: r.pkt_id,
ID: r.pkt_id(),
info: finfo[0],
}, nil
case "Readlink":
Expand All @@ -147,7 +177,7 @@ func fileinfo(h FileInfoer, r *Request) (responsePacket, error) {
}
filename := finfo[0].Name()
return &sshFxpNamePacket{
ID: r.pkt_id,
ID: r.pkt_id(),
NameAttrs: []sshFxpNameAttr{{
Name: filename,
LongName: filename,
Expand All @@ -161,50 +191,54 @@ func fileinfo(h FileInfoer, r *Request) (responsePacket, error) {
// populate attributes of request object from packet data
func (r *Request) populate(p interface{}) {
// r.Filepath should already be set
var pd packet_data
switch p := p.(type) {
case *sshFxpSetstatPacket:
r.Method = "Setstat"
r.Attrs = p.Attrs.([]byte)
r.pkt_id = p.id()
pd.id = p.id()
case *sshFxpFsetstatPacket:
r.Method = "Setstat"
r.Attrs = p.Attrs.([]byte)
r.pkt_id = p.id()
pd.id = p.id()
case *sshFxpRenamePacket:
r.Method = "Rename"
r.Target = filepath.Clean(p.Newpath)
r.pkt_id = p.id()
pd.id = p.id()
case *sshFxpSymlinkPacket:
r.Method = "Symlink"
r.Target = filepath.Clean(p.Linkpath)
r.pkt_id = p.id()
pd.id = p.id()
case *sshFxpReadPacket:
r.Method = "Get"
r.length = p.Len
r.pkt_id = p.id()
pd.length = p.Len
pd.offset = int64(p.Offset)
pd.id = p.id()
case *sshFxpWritePacket:
r.Method = "Put"
r.data = p.Data
r.length = p.Length
r.pkt_id = p.id()
pd.id = p.id()
pd.data = p.Data
pd.length = p.Length
pd.offset = int64(p.Offset)
case *sshFxpReaddirPacket:
r.Method = "List"
r.pkt_id = p.id()
pd.id = p.id()
case *sshFxpRemovePacket:
r.Method = "Remove"
r.pkt_id = p.id()
pd.id = p.id()
case *sshFxpStatPacket, *sshFxpLstatPacket, *sshFxpFstatPacket:
r.Method = "Stat"
r.pkt_id = p.(packet).id()
pd.id = p.(packet).id()
case *sshFxpRmdirPacket:
r.Method = "Rmdir"
r.pkt_id = p.id()
pd.id = p.id()
case *sshFxpReadlinkPacket:
r.Method = "Readlink"
r.pkt_id = p.id()
pd.id = p.id()
case *sshFxpMkdirPacket:
r.Method = "Mkdir"
r.pkt_id = p.id()
pd.id = p.id()
//r.Attrs are ignored in ./packet.go
}
r.pushPacket(pd)
}
Loading

0 comments on commit 94ae9c6

Please sign in to comment.