Add OnResponseHeaders callback
The callback is called just before reading the response body,
and can be used to abort the request.

Closes gocolly#228.
WGH- committed Feb 4, 2020
1 parent 1d31d5b commit 6b290a1
Showing 3 changed files with 101 additions and 20 deletions.
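
For orientation, a minimal usage sketch of the new API (the URL and the Content-Type check are illustrative; the abort-on-octet-stream pattern mirrors the test added in this commit, and the import path assumes colly v1):

package main

import (
	"fmt"

	"github.com/gocolly/colly"
)

func main() {
	c := colly.NewCollector()

	// Inspect headers before the body is read; skip binary downloads.
	c.OnResponseHeaders(func(r *colly.Response) {
		if r.Headers.Get("Content-Type") == "application/octet-stream" {
			r.Request.Abort()
		}
	})

	// Only reached when the transfer was not aborted.
	c.OnResponse(func(r *colly.Response) {
		fmt.Println("body length:", len(r.Body))
	})

	// An aborted transfer is expected to surface as ErrAbortedAfterHeaders
	// (an assumption based on how errors propagate out of fetch; the new
	// test does not assert this).
	if err := c.Visit("https://example.com/some/page"); err != nil {
		fmt.Println("visit:", err)
	}
}
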
71 changes: 55 additions & 16 deletions colly.go
@@ -109,26 +109,30 @@ type Collector struct {
CheckHead bool
// TraceHTTP enables capturing and reporting request performance for crawler tuning.
// When set to true, the Response.Trace will be filled in with an HTTPTrace object.
TraceHTTP bool
store storage.Storage
debugger debug.Debugger
robotsMap map[string]*robotstxt.RobotsData
htmlCallbacks []*htmlCallbackContainer
xmlCallbacks []*xmlCallbackContainer
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
errorCallbacks []ErrorCallback
scrapedCallbacks []ScrapedCallback
requestCount uint32
responseCount uint32
backend *httpBackend
wg *sync.WaitGroup
lock *sync.RWMutex
TraceHTTP bool
store storage.Storage
debugger debug.Debugger
robotsMap map[string]*robotstxt.RobotsData
htmlCallbacks []*htmlCallbackContainer
xmlCallbacks []*xmlCallbackContainer
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
responseHeadersCallbacks []ResponseHeadersCallback
errorCallbacks []ErrorCallback
scrapedCallbacks []ScrapedCallback
requestCount uint32
responseCount uint32
backend *httpBackend
wg *sync.WaitGroup
lock *sync.RWMutex
}

// RequestCallback is a type alias for OnRequest callback functions
type RequestCallback func(*Request)

// ResponseHeadersCallback is a type alias for OnResponseHeaders callback functions
type ResponseHeadersCallback func(*Response)

// ResponseCallback is a type alias for OnResponse callback functions
type ResponseCallback func(*Response)

@@ -196,6 +200,8 @@ var (
ErrNoPattern = errors.New("No pattern defined in LimitRule")
// ErrEmptyProxyURL is the error type for empty Proxy URL list
ErrEmptyProxyURL = errors.New("Proxy URL list is empty")
// ErrAbortedAfterHeaders is the error returned when OnResponseHeaders aborts the transfer.
ErrAbortedAfterHeaders = errors.New("Aborted after receiving response headers")
)

var envMap = map[string]func(*Collector, string){
@@ -626,9 +632,13 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct
hTrace = &HTTPTrace{}
req = hTrace.WithTrace(req)
}
checkHeadersFunc := func(statusCode int, headers http.Header) bool {
c.handleOnResponseHeaders(&Response{Ctx: ctx, Request: request, StatusCode: statusCode, Headers: &headers})
return !request.abort
}

origURL := req.URL
response, err := c.backend.Cache(req, c.MaxBodySize, c.CacheDir)
response, err := c.backend.Cache(req, c.MaxBodySize, checkHeadersFunc, c.CacheDir)
if proxyURL, ok := req.Context().Value(ProxyURLKey).(string); ok {
request.ProxyURL = proxyURL
}
@@ -791,6 +801,23 @@ func (c *Collector) OnRequest(f RequestCallback) {
c.lock.Unlock()
}

// OnResponseHeaders registers a function. Function will be executed on every response
// when headers and status are already received, but body is not yet read.
//
// Like in OnRequest, you can call Request.Abort to abort the transfer. This might be
// useful if, for example, you're following all hyperlinks, but want to avoid
// downloading files.
//
// Be aware that using this will prevent HTTP/1.1 connection reuse, as
// the only way to abort a download is to immediately close the connection.
// HTTP/2 doesn't suffer from this problem, as it's possible to close
// a specific stream inside the connection.
func (c *Collector) OnResponseHeaders(f ResponseHeadersCallback) {
c.lock.Lock()
c.responseHeadersCallbacks = append(c.responseHeadersCallbacks, f)
c.lock.Unlock()
}

// OnResponse registers a function. Function will be executed on every response
func (c *Collector) OnResponse(f ResponseCallback) {
c.lock.Lock()
@@ -987,6 +1014,18 @@ func (c *Collector) handleOnResponse(r *Response) {
}
}

func (c *Collector) handleOnResponseHeaders(r *Response) {
if c.debugger != nil {
c.debugger.Event(createEvent("responseHeaders", r.Request.ID, c.ID, map[string]string{
"url": r.Request.URL.String(),
"status": http.StatusText(r.StatusCode),
}))
}
for _, f := range c.responseHeadersCallbacks {
f(r)
}
}

func (c *Collector) handleOnHTML(resp *Response) error {
if len(c.htmlCallbacks) == 0 || !strings.Contains(strings.ToLower(resp.Headers.Get("Content-Type")), "html") {
return nil
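
handleOnResponseHeaders above also emits a "responseHeaders" event to an attached debugger. A sketch of observing it with the stock log debugger (the URL is a placeholder; colly.Debugger and debug.LogDebugger are existing colly v1 APIs):

package main

import (
	"log"
	"net/http"

	"github.com/gocolly/colly"
	"github.com/gocolly/colly/debug"
)

func main() {
	// The log debugger prints every collector event; with this commit a
	// "responseHeaders" event appears between "request" and "response".
	c := colly.NewCollector(
		colly.Debugger(&debug.LogDebugger{}),
	)

	c.OnResponseHeaders(func(r *colly.Response) {
		// The debugger event records http.StatusText(r.StatusCode) as "status".
		log.Println("headers received:", http.StatusText(r.StatusCode))
	})

	c.Visit("https://example.com/") // placeholder URL
}
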
35 changes: 35 additions & 0 deletions colly_test.go
@@ -15,6 +15,7 @@
package colly

import (
"bufio"
"bytes"
"fmt"
"net/http"
@@ -138,6 +139,18 @@ func newTestServer() *httptest.Server {
`))
})

mux.HandleFunc("/large_binary", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/octet-stream")
ww := bufio.NewWriter(w)
defer ww.Flush()
for {
// have to check error to detect client aborting download
if _, err := ww.Write([]byte{0x41}); err != nil {
return
}
}
})

return httptest.NewServer(mux)
}

@@ -388,6 +401,28 @@ func TestCollectorVisitWithDisallowedDomains(t *testing.T) {
}
}

func TestCollectorVisitResponseHeaders(t *testing.T) {
ts := newTestServer()
defer ts.Close()

var onResponseHeadersCalled bool

c := NewCollector()
c.OnResponseHeaders(func(r *Response) {
onResponseHeadersCalled = true
if r.Headers.Get("Content-Type") == "application/octet-stream" {
r.Request.Abort()
}
})
c.OnResponse(func(r *Response) {
t.Error("OnResponse was called")
})
c.Visit(ts.URL + "/large_binary")
if !onResponseHeadersCalled {
t.Error("OnResponseHeaders was not called")
}
}

func TestCollectorOnHTML(t *testing.T) {
ts := newTestServer()
defer ts.Close()
15 changes: 11 additions & 4 deletions http_backend.go
@@ -40,6 +40,8 @@ type httpBackend struct {
lock *sync.RWMutex
}

type checkHeadersFunc func(statusCode int, header http.Header) bool

// LimitRule provides connection restrictions for domains.
// Both DomainRegexp and DomainGlob can be used to specify
// the included domains patterns, but at least one is required.
@@ -127,9 +129,9 @@ func (h *httpBackend) GetMatchingRule(domain string) *LimitRule {
return nil
}

func (h *httpBackend) Cache(request *http.Request, bodySize int, cacheDir string) (*Response, error) {
func (h *httpBackend) Cache(request *http.Request, bodySize int, checkHeadersFunc checkHeadersFunc, cacheDir string) (*Response, error) {
if cacheDir == "" || request.Method != "GET" {
return h.Do(request, bodySize)
return h.Do(request, bodySize, checkHeadersFunc)
}
sum := sha1.Sum([]byte(request.URL.String()))
hash := hex.EncodeToString(sum[:])
@@ -143,7 +145,7 @@ func (h *httpBackend) Cache(request *http.Request, bodySize int, cacheDir string
return resp, err
}
}
resp, err := h.Do(request, bodySize)
resp, err := h.Do(request, bodySize, checkHeadersFunc)
if err != nil || resp.StatusCode >= 500 {
return resp, err
}
@@ -164,7 +166,7 @@ func (h *httpBackend) Cache(request *http.Request, bodySize int, cacheDir string
return resp, os.Rename(filename+"~", filename)
}

func (h *httpBackend) Do(request *http.Request, bodySize int) (*Response, error) {
func (h *httpBackend) Do(request *http.Request, bodySize int, checkHeadersFunc checkHeadersFunc) (*Response, error) {
r := h.GetMatchingRule(request.URL.Host)
if r != nil {
r.waitChan <- true
Expand All @@ -186,6 +188,11 @@ func (h *httpBackend) Do(request *http.Request, bodySize int) (*Response, error)
if res.Request != nil {
*request = *res.Request
}
if !checkHeadersFunc(res.StatusCode, res.Header) {
// closing res.Body (see defer above) without reading it aborts
// the download
return nil, ErrAbortedAfterHeaders
}

var bodyReader io.Reader = res.Body
if bodySize > 0 {
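
The early return in Do relies on standard net/http behaviour: closing a response body before reading it cancels the rest of the transfer, which is also why the OnResponseHeaders doc comment warns about HTTP/1.1 connection reuse. A standalone sketch of that mechanism, without colly (the URL and header check are illustrative):

package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	resp, err := http.Get("https://example.com/large_binary") // placeholder URL
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	if resp.Header.Get("Content-Type") == "application/octet-stream" {
		// Abort: return without reading the body. Over HTTP/1.1 the transport
		// must drop the connection, so it cannot be reused; over HTTP/2 only
		// the stream is reset.
		return
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("read", len(body), "bytes")
}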
