Skip to content

Commit

Permalink
Merge pull request gocolly#763 from WGH-/fix-setcookie-self-redirect
Browse files Browse the repository at this point in the history
Support websites redirecting to the same page when AllowURLRevisit is disabled
  • Loading branch information
asciimoo authored Apr 17, 2023
2 parents 336c8f7 + b4ca6a7 commit 70168cf
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 9 deletions.
22 changes: 15 additions & 7 deletions colly.go
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,12 @@ func (c *Collector) checkRedirectFunc() func(req *http.Request, via []*http.Requ
return fmt.Errorf("Not following redirect to %q: %w", req.URL, err)
}

if !c.AllowURLRevisit {
// allow redirects to the original destination
// to support websites redirecting to the same page while setting
// session cookies
samePageRedirect := normalizeURL(req.URL.String()) == normalizeURL(via[0].URL.String())

if !c.AllowURLRevisit && !samePageRedirect {
var body io.ReadCloser
if req.GetBody != nil {
var err error
Expand Down Expand Up @@ -1506,16 +1511,19 @@ func isMatchingFilter(fs []*regexp.Regexp, d []byte) bool {
return false
}

func normalizeURL(u string) string {
parsed, err := urlParser.Parse(u)
if err != nil {
return u
}
return parsed.String()
}

func requestHash(url string, body io.Reader) uint64 {
h := fnv.New64a()
// reparse the url to fix ambiguities such as
// "http://example.com" vs "http://example.com/"
parsedWhatwgURL, err := whatwgUrl.Parse(url)
if err == nil {
h.Write([]byte(parsedWhatwgURL.String()))
} else {
h.Write([]byte(url))
}
io.WriteString(h, normalizeURL(url))
if body != nil {
io.Copy(h, body)
}
Expand Down
69 changes: 67 additions & 2 deletions colly_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Disallow: /disallowed
Disallow: /allowed*q=
`

func newTestServer() *httptest.Server {
func newUnstartedTestServer() *httptest.Server {
mux := http.NewServeMux()

mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -253,7 +253,13 @@ y">link</a>
}
})

return httptest.NewServer(mux)
return httptest.NewUnstartedServer(mux)
}

func newTestServer() *httptest.Server {
srv := newUnstartedTestServer()
srv.Start()
return srv
}

var newCollectorTests = map[string]func(*testing.T){
Expand Down Expand Up @@ -712,6 +718,33 @@ func TestCollectorURLRevisitCheck(t *testing.T) {
}
}

func TestSetCookieRedirect(t *testing.T) {
type middleware = func(http.Handler) http.Handler
for _, m := range []middleware{
requireSessionCookieSimple,
requireSessionCookieAuthPage,
} {
t.Run("", func(t *testing.T) {
ts := newUnstartedTestServer()
ts.Config.Handler = m(ts.Config.Handler)
ts.Start()
defer ts.Close()
c := NewCollector()
c.OnResponse(func(r *Response) {
if got, want := r.Body, serverIndexResponse; !bytes.Equal(got, want) {
t.Errorf("bad response body got=%q want=%q", got, want)
}
if got, want := r.StatusCode, http.StatusOK; got != want {
t.Errorf("bad response code got=%d want=%d", got, want)
}
})
if err := c.Visit(ts.URL); err != nil {
t.Fatal(err)
}
})
}
}

func TestCollectorPostURLRevisitCheck(t *testing.T) {
ts := newTestServer()
defer ts.Close()
Expand Down Expand Up @@ -1587,3 +1620,35 @@ func BenchmarkOnResponse(b *testing.B) {
c.Visit(ts.URL)
}
}

func requireSessionCookieSimple(handler http.Handler) http.Handler {
const cookieName = "session_id"

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := r.Cookie(cookieName); err == http.ErrNoCookie {
http.SetCookie(w, &http.Cookie{Name: cookieName, Value: "1"})
http.Redirect(w, r, r.RequestURI, http.StatusFound)
return
}
handler.ServeHTTP(w, r)
})
}

func requireSessionCookieAuthPage(handler http.Handler) http.Handler {
const setCookiePath = "/auth"
const cookieName = "session_id"

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == setCookiePath {
destination := r.URL.Query().Get("return")
http.Redirect(w, r, destination, http.StatusFound)
return
}
if _, err := r.Cookie(cookieName); err == http.ErrNoCookie {
http.SetCookie(w, &http.Cookie{Name: cookieName, Value: "1"})
http.Redirect(w, r, setCookiePath+"?return="+url.QueryEscape(r.RequestURI), http.StatusFound)
return
}
handler.ServeHTTP(w, r)
})
}

0 comments on commit 70168cf

Please sign in to comment.