diff --git a/header.go b/header.go index cd90cc43f3..184519edf6 100644 --- a/header.go +++ b/header.go @@ -1457,10 +1457,9 @@ func (h *ResponseHeader) AppendBytes(dst []byte) []byte { dst = append(dst, statusLine(statusCode)...) server := h.Server() - if len(server) == 0 { - server = defaultServerName + if len(server) != 0 { + dst = appendHeaderLine(dst, strServer, server) } - dst = appendHeaderLine(dst, strServer, server) dst = appendHeaderLine(dst, strDate, serverDate.Load().([]byte)) // Append Content-Type only for non-zero responses diff --git a/http_test.go b/http_test.go index fe91ca4c3e..28e202880e 100644 --- a/http_test.go +++ b/http_test.go @@ -1372,7 +1372,7 @@ func TestResponseSuccess(t *testing.T) { // response with missing server testResponseSuccess(t, 500, "aaa", "", "aaadfsd", - 500, "aaa", string(defaultServerName)) + 500, "aaa", "") // empty body testResponseSuccess(t, 200, "bbb", "qwer", "", diff --git a/server.go b/server.go index 36098c12e5..aa626cae3f 100644 --- a/server.go +++ b/server.go @@ -265,6 +265,15 @@ type Server struct { // * cONTENT-lenGTH -> Content-Length DisableHeaderNamesNormalizing bool + // NoDefaultServerHeader, when set to true, causes the default Server header + // to be excluded from the Response. + // + // The default Server header value is the value of the Name field or an + // internal default value in its absence. With this option set to true, + // the only time a Server header will be sent is if a non-zero length + // value is explicitly provided during a request. + NoDefaultServerHeader bool + // Logger, which is used by RequestCtx.Logger(). // // By default standard logger from log package is used. @@ -1509,7 +1518,10 @@ const DefaultMaxRequestBodySize = 4 * 1024 * 1024 func (s *Server) serveConn(c net.Conn) error { defer s.wg.Done() - serverName := s.getServerName() + var serverName []byte + if !s.NoDefaultServerHeader { + serverName = s.getServerName() + } connRequestNum := uint64(0) connID := nextConnID() currentTime := time.Now() @@ -1581,7 +1593,7 @@ func (s *Server) serveConn(c net.Conn) error { if err == io.EOF { err = nil } else { - bw = writeErrorResponse(bw, ctx, err) + bw = writeErrorResponse(bw, ctx, serverName, err) } break } @@ -1611,7 +1623,7 @@ func (s *Server) serveConn(c net.Conn) error { br = nil } if err != nil { - bw = writeErrorResponse(bw, ctx, err) + bw = writeErrorResponse(bw, ctx, serverName, err) break } } @@ -1619,7 +1631,9 @@ func (s *Server) serveConn(c net.Conn) error { connectionClose = s.DisableKeepalive || ctx.Request.Header.connectionCloseFast() isHTTP11 = ctx.Request.Header.IsHTTP11() - ctx.Response.Header.SetServerBytes(serverName) + if serverName != nil { + ctx.Response.Header.SetServerBytes(serverName) + } ctx.connID = connID ctx.connRequestNum = connRequestNum ctx.time = currentTime @@ -1665,7 +1679,7 @@ func (s *Server) serveConn(c net.Conn) error { ctx.Response.Header.SetCanonical(strConnection, strKeepAlive) } - if len(ctx.Response.Header.Server()) == 0 { + if serverName != nil && len(ctx.Response.Header.Server()) == 0 { ctx.Response.Header.SetServerBytes(serverName) } @@ -2038,22 +2052,31 @@ func (s *Server) getServerName() []byte { func (s *Server) writeFastError(w io.Writer, statusCode int, msg string) { w.Write(statusLine(statusCode)) + + server := "" + if !s.NoDefaultServerHeader { + server = fmt.Sprintf("Server: %s\r\n", s.getServerName()) + } + fmt.Fprintf(w, "Connection: close\r\n"+ - "Server: %s\r\n"+ + server+ "Date: %s\r\n"+ "Content-Type: text/plain\r\n"+ "Content-Length: %d\r\n"+ "\r\n"+ "%s", - s.getServerName(), serverDate.Load(), len(msg), msg) + serverDate.Load(), len(msg), msg) } -func writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, err error) *bufio.Writer { +func writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverName []byte, err error) *bufio.Writer { if _, ok := err.(*ErrSmallBuffer); ok { ctx.Error("Too big request header", StatusRequestHeaderFieldsTooLarge) } else { ctx.Error("Error when parsing request", StatusBadRequest) } + if serverName != nil { + ctx.Response.Header.SetServerBytes(serverName) + } ctx.SetConnectionClose() if bw == nil { bw = acquireWriter(ctx) diff --git a/server_test.go b/server_test.go index d5615f6280..da9c0c0a39 100644 --- a/server_test.go +++ b/server_test.go @@ -17,6 +17,67 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) +func TestServerName(t *testing.T) { + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + + getReponse := func() []byte { + rw := &readWriter{} + rw.r.WriteString("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n") + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } + + resp, err := ioutil.ReadAll(&rw.w) + if err != nil { + t.Fatalf("Unexpected error from ReadAll: %s", err) + } + + return resp + } + + resp := getReponse() + if !bytes.Contains(resp, []byte("\r\nServer: "+string(defaultServerName)+"\r\n")) { + t.Fatalf("Unexpected response %q expected Server: "+string(defaultServerName), resp) + } + + // We can't just overwrite s.Name as fasthttp caches the name in an atomic.Value + s = &Server{ + Handler: func(ctx *RequestCtx) { + }, + Name: "foobar", + } + + resp = getReponse() + if !bytes.Contains(resp, []byte("\r\nServer: foobar\r\n")) { + t.Fatalf("Unexpected response %q expected Server: foobar", resp) + } + + s = &Server{ + Handler: func(ctx *RequestCtx) { + }, + NoDefaultServerHeader: true, + } + + resp = getReponse() + if bytes.Contains(resp, []byte("\r\nServer: ")) { + t.Fatalf("Unexpected response %q expected no Server header", resp) + } +} + func TestRequestCtxString(t *testing.T) { var ctx RequestCtx