Skip to content

Commit

Permalink
Let Backends return their own error code (custom result or error) (fl…
Browse files Browse the repository at this point in the history
…ashmob#113)

add ability for backends to specify a custom return code, fixes flashmob#78
  • Loading branch information
flashmob authored Jun 8, 2018
1 parent 29dd65f commit 15a3295
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 170 deletions.
59 changes: 58 additions & 1 deletion api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package guerrilla

import (
"bufio"
"errors"
"fmt"
"github.com/flashmob/go-guerrilla/backends"
"github.com/flashmob/go-guerrilla/log"
"github.com/flashmob/go-guerrilla/mail"
"github.com/flashmob/go-guerrilla/response"
"io/ioutil"
"net"
"os"
Expand Down Expand Up @@ -349,7 +351,7 @@ var funkyLogger = func() backends.Decorator {
return backends.ProcessWith(
func(e *mail.Envelope, task backends.SelectTask) (backends.Result, error) {
if task == backends.TaskValidateRcpt {
// validate the last recipient appended to e.Rcpt
// log the last recipient appended to e.Rcpt
backends.Log().Infof(
"another funky recipient [%s]",
e.RcptTo[len(e.RcptTo)-1])
Expand Down Expand Up @@ -556,3 +558,58 @@ func TestSkipAllowsHost(t *testing.T) {
}

}

var customBackend2 = func() backends.Decorator {

return func(p backends.Processor) backends.Processor {
return backends.ProcessWith(
func(e *mail.Envelope, task backends.SelectTask) (backends.Result, error) {
if task == backends.TaskValidateRcpt {
return p.Process(e, task)
} else if task == backends.TaskSaveMail {
backends.Log().Info("Another funky email!")
err := errors.New("system shock")
return backends.NewResult(response.Canned.FailReadErrorDataCmd, response.SP, err), err
}
return p.Process(e, task)
})
}
}

// Test a custom backend response
func TestCustomBackendResult(t *testing.T) {
os.Truncate("tests/testlog", 0)
cfg := &AppConfig{
LogFile: "tests/testlog",
AllowedHosts: []string{"grr.la"},
BackendConfig: backends.BackendConfig{
"save_process": "HeadersParser|Debugger|Custom",
"validate_process": "Custom",
},
}
d := Daemon{Config: cfg}
d.AddProcessor("Custom", customBackend2)

if err := d.Start(); err != nil {
t.Error(err)
}
// lets have a talk with the server
talkToServer("127.0.0.1:2525")

d.Shutdown()

b, err := ioutil.ReadFile("tests/testlog")
if err != nil {
t.Error("could not read logfile")
return
}
// lets check for fingerprints
if strings.Index(string(b), "451 4.3.0 Error") < 0 {
t.Error("did not log: 451 4.3.0 Error")
}

if strings.Index(string(b), "system shock") < 0 {
t.Error("did not log: system shock")
}

}
29 changes: 22 additions & 7 deletions backends/backend.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package backends

import (
"bytes"
"fmt"
"github.com/flashmob/go-guerrilla/log"
"github.com/flashmob/go-guerrilla/mail"
Expand Down Expand Up @@ -54,6 +55,7 @@ type BaseConfig interface{}
type notifyMsg struct {
err error
queuedID string
result Result
}

// Result represents a response to an SMTP client after receiving DATA.
Expand All @@ -66,16 +68,18 @@ type Result interface {
}

// Internal implementation of BackendResult for use by backend implementations.
type result string
type result struct {
bytes.Buffer
}

func (br result) String() string {
return string(br)
func (r *result) String() string {
return r.Buffer.String()
}

// Parses the SMTP code from the first 3 characters of the SMTP message.
// Returns 554 if code cannot be parsed.
func (br result) Code() int {
trimmed := strings.TrimSpace(string(br))
func (r *result) Code() int {
trimmed := strings.TrimSpace(r.String())
if len(trimmed) < 3 {
return 554
}
Expand All @@ -86,8 +90,19 @@ func (br result) Code() int {
return code
}

func NewResult(message string) Result {
return result(message)
func NewResult(r ...interface{}) Result {
buf := new(result)
for _, item := range r {
switch v := item.(type) {
case error:
buf.WriteString(v.Error())
case fmt.Stringer:
buf.WriteString(v.String())
case string:
buf.WriteString(v)
}
}
return buf
}

type processorInitializer interface {
Expand Down
54 changes: 30 additions & 24 deletions backends/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (w *workerMsg) reset(e *mail.Envelope, task SelectTask) {
// Process distributes an envelope to one of the backend workers with a TaskSaveMail task
func (gw *BackendGateway) Process(e *mail.Envelope) Result {
if gw.State != BackendStateRunning {
return NewResult(response.Canned.FailBackendNotRunning + gw.State.String())
return NewResult(response.Canned.FailBackendNotRunning, response.SP, gw.State)
}
// borrow a workerMsg from the pool
workerMsg := workerMsgPool.Get().(*workerMsg)
Expand All @@ -139,11 +139,32 @@ func (gw *BackendGateway) Process(e *mail.Envelope) Result {
// or timeout
select {
case status := <-workerMsg.notifyMe:
workerMsgPool.Put(workerMsg) // can be recycled since we used the notifyMe channel
// email saving transaction completed
if status.result == BackendResultOK && status.queuedID != "" {
return NewResult(response.Canned.SuccessMessageQueued, response.SP, status.queuedID)
}

// A custom result, there was probably an error, if so, log it
if status.result != nil {
if status.err != nil {
Log().Error(status.err)
}
return status.result
}

// if there was no result, but there's an error, then make a new result from the error
if status.err != nil {
return NewResult(response.Canned.FailBackendTransaction + status.err.Error())
if _, err := strconv.Atoi(status.err.Error()[:3]); err != nil {
return NewResult(response.Canned.FailBackendTransaction, response.SP, status.err)
}
return NewResult(status.err)
}
return NewResult(response.Canned.SuccessMessageQueued + status.queuedID)

// both result & error are nil (should not happen)
err := errors.New("no response from backend - processor did not return a result or an error")
Log().Error(err)
return NewResult(response.Canned.FailBackendTransaction, response.SP, err)

case <-time.After(gw.saveTimeout()):
Log().Error("Backend has timed out while saving email")
e.Lock() // lock the envelope - it's still processing here, we don't want the server to recycle it
Expand Down Expand Up @@ -434,27 +455,12 @@ func (gw *BackendGateway) workDispatcher(
return
case msg = <-workIn:
state = dispatcherStateWorking // recovers from panic if in this state
result, err := save.Process(msg.e, msg.task)
state = dispatcherStateNotify
if msg.task == TaskSaveMail {
// process the email here
result, _ := save.Process(msg.e, TaskSaveMail)
state = dispatcherStateNotify
if result.Code() < 300 {
// if all good, let the gateway know that it was saved
msg.notifyMe <- &notifyMsg{nil, msg.e.QueuedId}
} else {
// notify the gateway about the error
msg.notifyMe <- &notifyMsg{err: errors.New(result.String())}
}
} else if msg.task == TaskValidateRcpt {
_, err := validate.Process(msg.e, TaskValidateRcpt)
state = dispatcherStateNotify
if err != nil {
// validation failed
msg.notifyMe <- &notifyMsg{err: err}
} else {
// all good.
msg.notifyMe <- &notifyMsg{err: nil}
}
msg.notifyMe <- &notifyMsg{err: err, result: result, queuedID: msg.e.QueuedId}
} else {
msg.notifyMe <- &notifyMsg{err: err, result: result}
}
}
state = dispatcherStateIdle
Expand Down
60 changes: 30 additions & 30 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ type client struct {
errors int
state ClientState
messagesSent int
// Response to be written to the client
// Response to be written to the client (for debugging)
response bytes.Buffer
bufErr error
conn net.Conn
bufin *smtpBufferedReader
bufout *bufio.Writer
Expand Down Expand Up @@ -69,39 +70,38 @@ func NewClient(conn net.Conn, clientID uint64, logger log.Logger, envelope *mail
return c
}

// setResponse adds a response to be written on the next turn
// sendResponse adds a response to be written on the next turn
// the response gets buffered
func (c *client) sendResponse(r ...interface{}) {
c.bufout.Reset(c.conn)
if c.log.IsDebug() {
// us additional buffer so that we can log the response in debug mode only
// an additional buffer so that we can log the response in debug mode only
c.response.Reset()
}
var out string
if c.bufErr != nil {
c.bufErr = nil
}
for _, item := range r {
switch v := item.(type) {
case string:
if _, err := c.bufout.WriteString(v); err != nil {
c.log.WithError(err).Error("could not write to c.bufout")
}
if c.log.IsDebug() {
c.response.WriteString(v)
}
case error:
if _, err := c.bufout.WriteString(v.Error()); err != nil {
c.log.WithError(err).Error("could not write to c.bufout")
}
if c.log.IsDebug() {
c.response.WriteString(v.Error())
}
out = v.Error()
case fmt.Stringer:
if _, err := c.bufout.WriteString(v.String()); err != nil {
c.log.WithError(err).Error("could not write to c.bufout")
}
if c.log.IsDebug() {
c.response.WriteString(v.String())
}
out = v.String()
case string:
out = v
}
if _, c.bufErr = c.bufout.WriteString(out); c.bufErr != nil {
c.log.WithError(c.bufErr).Error("could not write to c.bufout")
}
if c.log.IsDebug() {
c.response.WriteString(out)
}
if c.bufErr != nil {
return
}
}
c.bufout.WriteString("\r\n")
_, c.bufErr = c.bufout.WriteString("\r\n")
if c.log.IsDebug() {
c.response.WriteString("\r\n")
}
Expand Down Expand Up @@ -176,20 +176,20 @@ func (c *client) getID() uint64 {
}

// UpgradeToTLS upgrades a client connection to TLS
func (client *client) upgradeToTLS(tlsConfig *tls.Config) error {
func (c *client) upgradeToTLS(tlsConfig *tls.Config) error {
var tlsConn *tls.Conn
// wrap client.conn in a new TLS server side connection
tlsConn = tls.Server(client.conn, tlsConfig)
// wrap c.conn in a new TLS server side connection
tlsConn = tls.Server(c.conn, tlsConfig)
// Call handshake here to get any handshake error before reading starts
err := tlsConn.Handshake()
if err != nil {
return err
}
// convert tlsConn to net.Conn
client.conn = net.Conn(tlsConn)
client.bufout.Reset(client.conn)
client.bufin.Reset(client.conn)
client.TLS = true
c.conn = net.Conn(tlsConn)
c.bufout.Reset(c.conn)
c.bufin.Reset(c.conn)
c.TLS = true
return err
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/guerrillad/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func readConfig(path string, pidFile string) (*guerrilla.AppConfig, error) {
// command line flags can override config values
appConfig, err := d.LoadConfig(path)
if err != nil {
return &appConfig, fmt.Errorf("Could not read config file: %s", err.Error())
return &appConfig, fmt.Errorf("could not read config file: %s", err.Error())
}
// override config pidFile with with flag from the command line
if len(pidFile) > 0 {
Expand Down
Loading

0 comments on commit 15a3295

Please sign in to comment.