@@ -43,7 +43,7 @@ Disallow: /disallowed
43
43
Disallow: /allowed*q=
44
44
`
45
45
46
- func newTestServer () * httptest.Server {
46
+ func newUnstartedTestServer () * httptest.Server {
47
47
mux := http .NewServeMux ()
48
48
49
49
mux .HandleFunc ("/" , func (w http.ResponseWriter , r * http.Request ) {
@@ -253,7 +253,13 @@ y">link</a>
253
253
}
254
254
})
255
255
256
- return httptest .NewServer (mux )
256
+ return httptest .NewUnstartedServer (mux )
257
+ }
258
+
259
+ func newTestServer () * httptest.Server {
260
+ srv := newUnstartedTestServer ()
261
+ srv .Start ()
262
+ return srv
257
263
}
258
264
259
265
var newCollectorTests = map [string ]func (* testing.T ){
@@ -712,6 +718,33 @@ func TestCollectorURLRevisitCheck(t *testing.T) {
712
718
}
713
719
}
714
720
721
+ func TestSetCookieRedirect (t * testing.T ) {
722
+ type middleware = func (http.Handler ) http.Handler
723
+ for _ , m := range []middleware {
724
+ requireSessionCookieSimple ,
725
+ requireSessionCookieAuthPage ,
726
+ } {
727
+ t .Run ("" , func (t * testing.T ) {
728
+ ts := newUnstartedTestServer ()
729
+ ts .Config .Handler = m (ts .Config .Handler )
730
+ ts .Start ()
731
+ defer ts .Close ()
732
+ c := NewCollector ()
733
+ c .OnResponse (func (r * Response ) {
734
+ if got , want := r .Body , serverIndexResponse ; ! bytes .Equal (got , want ) {
735
+ t .Errorf ("bad response body got=%q want=%q" , got , want )
736
+ }
737
+ if got , want := r .StatusCode , http .StatusOK ; got != want {
738
+ t .Errorf ("bad response code got=%d want=%d" , got , want )
739
+ }
740
+ })
741
+ if err := c .Visit (ts .URL ); err != nil {
742
+ t .Fatal (err )
743
+ }
744
+ })
745
+ }
746
+ }
747
+
715
748
func TestCollectorPostURLRevisitCheck (t * testing.T ) {
716
749
ts := newTestServer ()
717
750
defer ts .Close ()
@@ -1587,3 +1620,35 @@ func BenchmarkOnResponse(b *testing.B) {
1587
1620
c .Visit (ts .URL )
1588
1621
}
1589
1622
}
1623
+
1624
+ func requireSessionCookieSimple (handler http.Handler ) http.Handler {
1625
+ const cookieName = "session_id"
1626
+
1627
+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1628
+ if _ , err := r .Cookie (cookieName ); err == http .ErrNoCookie {
1629
+ http .SetCookie (w , & http.Cookie {Name : cookieName , Value : "1" })
1630
+ http .Redirect (w , r , r .RequestURI , http .StatusFound )
1631
+ return
1632
+ }
1633
+ handler .ServeHTTP (w , r )
1634
+ })
1635
+ }
1636
+
1637
+ func requireSessionCookieAuthPage (handler http.Handler ) http.Handler {
1638
+ const setCookiePath = "/auth"
1639
+ const cookieName = "session_id"
1640
+
1641
+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1642
+ if r .URL .Path == setCookiePath {
1643
+ destination := r .URL .Query ().Get ("return" )
1644
+ http .Redirect (w , r , destination , http .StatusFound )
1645
+ return
1646
+ }
1647
+ if _ , err := r .Cookie (cookieName ); err == http .ErrNoCookie {
1648
+ http .SetCookie (w , & http.Cookie {Name : cookieName , Value : "1" })
1649
+ http .Redirect (w , r , setCookiePath + "?return=" + url .QueryEscape (r .RequestURI ), http .StatusFound )
1650
+ return
1651
+ }
1652
+ handler .ServeHTTP (w , r )
1653
+ })
1654
+ }
0 commit comments