Skip to content

Commit

Permalink
Merge pull request gocolly#468 from asalih/master
Browse files Browse the repository at this point in the history
Revisit detection improvement
  • Loading branch information
asciimoo authored Apr 28, 2020
2 parents 2f67e35 + adc32e2 commit 473feba
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 7 deletions.
51 changes: 44 additions & 7 deletions colly.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,13 @@ func (c *Collector) Visit(URL string) error {

// HasVisited checks if the provided URL has been visited
func (c *Collector) HasVisited(URL string) (bool, error) {
h := fnv.New64a()
h.Write([]byte(URL))
return c.checkHasVisited(URL, nil)
}

return c.store.IsVisited(h.Sum64())
// HasPosted checks if a request with the provided URL and requestData has
// already been visited. It is useful for avoiding a repeated POST of the
// same body to the same URL. A nil requestData makes the check depend on the
// URL alone.
func (c *Collector) HasPosted(URL string, requestData map[string]string) (bool, error) {
return c.checkHasVisited(URL, requestData)
}

// Head starts a collector job by creating a HEAD request.
Expand Down Expand Up @@ -537,7 +540,7 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
if err != nil {
return err
}
if err := c.requestCheck(u, parsedURL, method, depth, checkRevisit); err != nil {
if err := c.requestCheck(u, parsedURL, method, requestData, depth, checkRevisit); err != nil {
return err
}

Expand Down Expand Up @@ -685,7 +688,7 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct
return err
}

func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, depth int, checkRevisit bool) error {
func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, requestData io.Reader, depth int, checkRevisit bool) error {
if u == "" {
return ErrMissingURL
}
Expand All @@ -710,10 +713,20 @@ func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, de
return err
}
}
if checkRevisit && !c.AllowURLRevisit && method == "GET" {
if checkRevisit && !c.AllowURLRevisit {
h := fnv.New64a()
h.Write([]byte(u))
uHash := h.Sum64()

var uHash uint64
if method == "GET" {
uHash = h.Sum64()
} else if requestData != nil {
h.Write(streamToByte(requestData))
uHash = h.Sum64()
} else {
return nil
}

visited, err := c.store.IsVisited(uHash)
if err != nil {
return err
Expand Down Expand Up @@ -1283,6 +1296,17 @@ func (c *Collector) parseSettingsFromEnv() {
}
}

// checkHasVisited asks the collector's store whether the given request has
// been recorded as visited. The lookup key is the FNV-1a 64-bit hash of the
// URL, extended with the form-encoded requestData when requestData is non-nil.
func (c *Collector) checkHasVisited(URL string, requestData map[string]string) (bool, error) {
	hasher := fnv.New64a()
	hasher.Write([]byte(URL))

	// Without form data the URL alone identifies the request.
	if requestData == nil {
		return c.store.IsVisited(hasher.Sum64())
	}

	formBody := streamToByte(createFormReader(requestData))
	hasher.Write(formBody)

	return c.store.IsVisited(hasher.Sum64())
}

// SanitizeFileName replaces dangerous characters in a string
// so the return value can be used as a safe file name.
func SanitizeFileName(fileName string) string {
Expand Down Expand Up @@ -1391,3 +1415,16 @@ func isMatchingFilter(fs []*regexp.Regexp, d []byte) bool {
}
return false
}

// streamToByte drains r and returns everything that was read from it.
//
// If r also implements io.Seeker (as *strings.Reader and *bytes.Reader do) it
// is rewound to its start afterwards, so the same reader can still be consumed
// later, e.g. as the body of the outgoing request. Readers that cannot seek
// are left drained.
//
// A nil reader yields a nil slice. Read and seek errors are not reported:
// the bytes read before a failure are returned as-is.
func streamToByte(r io.Reader) []byte {
	if r == nil {
		return nil
	}

	var buf bytes.Buffer
	// Best effort: a partial read still produces a usable (if incomplete) result.
	_, _ = buf.ReadFrom(r)

	// Rewind any seekable reader, not just the two concrete types, so the
	// data is not lost for whoever reads r next.
	if seeker, ok := r.(io.Seeker); ok {
		_, _ = seeker.Seek(0, io.SeekStart)
	}

	return buf.Bytes()
}
146 changes: 146 additions & 0 deletions colly_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,45 @@ func TestCollectorURLRevisit(t *testing.T) {
}
}

// TestCollectorPostRevisit checks that repeating a POST with the same URL and
// form data produces no additional response unless AllowURLRevisit is set,
// while a POST with different form data is still performed.
func TestCollectorPostRevisit(t *testing.T) {
	server := newTestServer()
	defer server.Close()

	const want = "hello"
	loginURL := server.URL + "/login"
	form := map[string]string{"name": want}
	extendedForm := map[string]string{
		"name":     want,
		"lastname": "world",
	}

	responses := 0
	collector := NewCollector()
	collector.OnResponse(func(resp *Response) {
		if string(resp.Body) != want {
			t.Error("Failed to send data with POST")
		}
		responses++
	})

	// Three requests, but the second is an exact duplicate of the first.
	collector.Post(loginURL, form)
	collector.Post(loginURL, form)
	collector.Post(loginURL, extendedForm)

	if responses != 2 {
		t.Error("URL POST revisited")
	}

	// With revisits allowed, every request must go through.
	collector.AllowURLRevisit = true
	for i := 0; i < 2; i++ {
		collector.Post(loginURL, form)
	}

	if responses != 4 {
		t.Error("URL POST not revisited")
	}
}

func TestCollectorURLRevisitCheck(t *testing.T) {
ts := newTestServer()
defer ts.Close()
Expand Down Expand Up @@ -539,6 +578,63 @@ func TestCollectorURLRevisitCheck(t *testing.T) {
}
}

// TestCollectorPostURLRevisitCheck checks that HasPosted reports a URL/form
// pair as visited only after that exact pair has been posted, and that
// changing the form data makes the pair unvisited again.
func TestCollectorPostURLRevisitCheck(t *testing.T) {
	server := newTestServer()
	defer server.Close()

	collector := NewCollector()
	loginURL := server.URL + "/login"
	form := map[string]string{"name": "hello"}

	// expectPosted queries HasPosted for the current contents of form and
	// reports a test error when the answer differs from want.
	expectPosted := func(want bool) {
		t.Helper()
		got, err := collector.HasPosted(loginURL, form)
		if err != nil {
			t.Error(err.Error())
		}
		switch {
		case got == want:
			return
		case want:
			t.Error("Expected URL to have been visited")
		default:
			t.Error("Expected URL to NOT have been visited")
		}
	}

	expectPosted(false)

	collector.Post(loginURL, form)
	expectPosted(true)

	// A different body for the same URL counts as a new request.
	form["lastname"] = "world"
	expectPosted(false)

	collector.Post(loginURL, form)
	expectPosted(true)
}

// TestCollectorURLRevisitDomainDisallowed ensures that a disallowed URL is not considered visited.
func TestCollectorURLRevisitDomainDisallowed(t *testing.T) {
ts := newTestServer()
Expand Down Expand Up @@ -579,6 +675,56 @@ func TestCollectorPost(t *testing.T) {
})
}

// TestCollectorPostRaw checks that a raw POST body sent with PostRaw reaches
// the test server's /login handler and that the response body equals the
// posted "name" value.
func TestCollectorPostRaw(t *testing.T) {
	server := newTestServer()
	defer server.Close()

	const want = "hello"

	collector := NewCollector()
	collector.OnResponse(func(resp *Response) {
		if got := string(resp.Body); got != want {
			t.Error("Failed to send data with POST")
		}
	})

	collector.PostRaw(server.URL+"/login", []byte("name="+want))
}

// TestCollectorPostRawRevisit checks that repeating a PostRaw with the same
// URL and body produces no additional response unless AllowURLRevisit is set,
// while a PostRaw with a different body is still performed.
func TestCollectorPostRawRevisit(t *testing.T) {
	server := newTestServer()
	defer server.Close()

	const want = "hello"
	loginURL := server.URL + "/login"
	rawBody := "name=" + want

	responses := 0
	collector := NewCollector()
	collector.OnResponse(func(resp *Response) {
		if string(resp.Body) != want {
			t.Error("Failed to send data with POST RAW")
		}
		responses++
	})

	// Three requests, but the second is an exact duplicate of the first.
	collector.PostRaw(loginURL, []byte(rawBody))
	collector.PostRaw(loginURL, []byte(rawBody))
	collector.PostRaw(loginURL, []byte(rawBody+"&lastname=world"))

	if responses != 2 {
		t.Error("URL POST RAW revisited")
	}

	// With revisits allowed, every request must go through.
	collector.AllowURLRevisit = true
	for i := 0; i < 2; i++ {
		collector.PostRaw(loginURL, []byte(rawBody))
	}

	if responses != 4 {
		t.Error("URL POST RAW not revisited")
	}
}

func TestRedirect(t *testing.T) {
ts := newTestServer()
defer ts.Close()
Expand Down

0 comments on commit 473feba

Please sign in to comment.