Skip to content

Commit

Permalink
fix revisit detection and fixing/adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
asalih committed Apr 26, 2020
1 parent 7b7ce71 commit adc32e2
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 7 deletions.
36 changes: 31 additions & 5 deletions colly.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package colly

import (
"bufio"
"bytes"
"context"
"crypto/rand"
Expand Down Expand Up @@ -449,10 +448,13 @@ func (c *Collector) Visit(URL string) error {

// HasVisited checks if the provided URL has been visited
func (c *Collector) HasVisited(URL string) (bool, error) {
h := fnv.New64a()
h.Write([]byte(URL))
return c.checkHasVisited(URL, nil)
}

return c.store.IsVisited(h.Sum64())
// HasPosted reports whether the combination of URL and requestData has
// already been visited. It can be used to avoid re-submitting the same
// POST body to the same URL.
func (c *Collector) HasPosted(URL string, requestData map[string]string) (bool, error) {
	visited, err := c.checkHasVisited(URL, requestData)
	return visited, err
}

// Head starts a collector job by creating a HEAD request.
Expand Down Expand Up @@ -719,7 +721,7 @@ func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, re
if method == "GET" {
uHash = h.Sum64()
} else if requestData != nil {
h.Write(bufio.NewScanner(requestData).Bytes())
h.Write(streamToByte(requestData))
uHash = h.Sum64()
} else {
return nil
Expand Down Expand Up @@ -1294,6 +1296,17 @@ func (c *Collector) parseSettingsFromEnv() {
}
}

// checkHasVisited asks the collector's store whether a request has been
// visited. The lookup key is the 64-bit FNV-1a hash of URL, extended with
// the bytes read from createFormReader(requestData) when requestData is
// non-nil.
func (c *Collector) checkHasVisited(URL string, requestData map[string]string) (bool, error) {
	hasher := fnv.New64a()
	hasher.Write([]byte(URL))

	// Without form data the key is derived from the URL alone.
	if requestData == nil {
		return c.store.IsVisited(hasher.Sum64())
	}

	body := streamToByte(createFormReader(requestData))
	hasher.Write(body)
	return c.store.IsVisited(hasher.Sum64())
}

// SanitizeFileName replaces dangerous characters in a string
// so the return value can be used as a safe file name.
func SanitizeFileName(fileName string) string {
Expand Down Expand Up @@ -1402,3 +1415,16 @@ func isMatchingFilter(fs []*regexp.Regexp, d []byte) bool {
}
return false
}

// streamToByte reads r until EOF and returns the bytes that were read.
//
// When r implements io.Seeker (for example *strings.Reader, *bytes.Reader
// or *io.SectionReader) it is moved back to the offset it had on entry, so
// the same reader can still be consumed afterwards. Readers that cannot
// seek are left drained. A nil reader yields a nil slice.
//
// The function is best effort: it has no error return, so if reading fails
// the bytes obtained before the failure are returned.
func streamToByte(r io.Reader) []byte {
	if r == nil {
		return nil
	}

	// Record the starting offset up front. A reader whose Seek fails is
	// treated as non-seekable and is not rewound later.
	var (
		seeker io.Seeker
		start  int64
	)
	if s, ok := r.(io.Seeker); ok {
		if pos, err := s.Seek(0, io.SeekCurrent); err == nil {
			seeker, start = s, pos
		}
	}

	buf := new(bytes.Buffer)
	// Error deliberately ignored: see the best-effort note above.
	_, _ = buf.ReadFrom(r)

	if seeker != nil {
		// Restore the entry offset rather than jumping to absolute zero, so
		// the returned bytes match what a later read of r will produce.
		_, _ = seeker.Seek(start, io.SeekStart)
	}

	return buf.Bytes()
}
115 changes: 113 additions & 2 deletions colly_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,12 @@ func TestCollectorPostRevisit(t *testing.T) {

c.Post(ts.URL+"/login", postData)
c.Post(ts.URL+"/login", postData)
c.Post(ts.URL+"/login", map[string]string{
"name": postValue,
"lastname": "world",
})

if visitCount != 1 {
if visitCount != 2 {
t.Error("URL POST revisited")
}

Expand All @@ -540,7 +544,7 @@ func TestCollectorPostRevisit(t *testing.T) {
c.Post(ts.URL+"/login", postData)
c.Post(ts.URL+"/login", postData)

if visitCount != 3 {
if visitCount != 4 {
t.Error("URL POST not revisited")
}
}
Expand Down Expand Up @@ -574,6 +578,63 @@ func TestCollectorURLRevisitCheck(t *testing.T) {
}
}

// TestCollectorPostURLRevisitCheck verifies that HasPosted reports a
// URL/form-data pair as visited only after that exact pair has been sent
// with Post, and that adding a field to the form makes it unvisited again.
func TestCollectorPostURLRevisitCheck(t *testing.T) {
	ts := newTestServer()
	defer ts.Close()

	c := NewCollector()

	const (
		msgNotVisited = "Expected URL to NOT have been visited"
		msgVisited    = "Expected URL to have been visited"
	)

	loginURL := ts.URL + "/login"
	form := map[string]string{
		"name": "hello",
	}

	// expectPosted queries HasPosted with the current contents of form and
	// records a failure when the answer differs from want.
	expectPosted := func(want bool, failMsg string) {
		t.Helper()
		got, err := c.HasPosted(loginURL, form)
		if err != nil {
			t.Error(err.Error())
		}
		if got != want {
			t.Error(failMsg)
		}
	}

	// Nothing has been posted yet.
	expectPosted(false, msgNotVisited)

	c.Post(loginURL, form)
	expectPosted(true, msgVisited)

	// A different body for the same URL counts as a new request.
	form["lastname"] = "world"
	expectPosted(false, msgNotVisited)

	c.Post(loginURL, form)
	expectPosted(true, msgVisited)
}

// TestCollectorURLRevisitDomainDisallowed ensures that a disallowed URL is not considered visited.
func TestCollectorURLRevisitDomainDisallowed(t *testing.T) {
ts := newTestServer()
Expand Down Expand Up @@ -614,6 +675,56 @@ func TestCollectorPost(t *testing.T) {
})
}

// TestCollectorPostRaw sends a raw POST body to the test server's /login
// path with PostRaw and checks that the response body equals the posted
// "name" value.
func TestCollectorPostRaw(t *testing.T) {
	ts := newTestServer()
	defer ts.Close()

	postValue := "hello"
	c := NewCollector()

	// Track whether the callback ran at all: without this the test would
	// pass silently when no response is ever delivered.
	responded := false
	c.OnResponse(func(r *Response) {
		responded = true
		if postValue != string(r.Body) {
			t.Error("Failed to send data with POST")
		}
	})

	if err := c.PostRaw(ts.URL+"/login", []byte("name="+postValue)); err != nil {
		t.Fatal(err)
	}

	if !responded {
		t.Error("OnResponse callback was not invoked for POST")
	}
}

// TestCollectorPostRawRevisit checks revisit detection for raw POST bodies.
// Two identical bodies plus one different body are expected to produce two
// responses; once AllowURLRevisit is set, two more identical bodies are
// expected to bring the total to four.
func TestCollectorPostRawRevisit(t *testing.T) {
	ts := newTestServer()
	defer ts.Close()

	postValue := "hello"
	postData := "name=" + postValue
	responses := 0

	c := NewCollector()
	c.OnResponse(func(r *Response) {
		if string(r.Body) != postValue {
			t.Error("Failed to send data with POST RAW")
		}
		responses++
	})

	loginURL := ts.URL + "/login"
	// sendRaw posts each body, in order, to the login URL.
	sendRaw := func(bodies ...string) {
		for _, body := range bodies {
			c.PostRaw(loginURL, []byte(body))
		}
	}

	sendRaw(postData, postData, postData+"&lastname=world")
	if responses != 2 {
		t.Error("URL POST RAW revisited")
	}

	c.AllowURLRevisit = true

	sendRaw(postData, postData)
	if responses != 4 {
		t.Error("URL POST RAW not revisited")
	}
}

func TestRedirect(t *testing.T) {
ts := newTestServer()
defer ts.Close()
Expand Down

0 comments on commit adc32e2

Please sign in to comment.