Skip to content

Commit

Permalink
Merge pull request oauth2-proxy#2539 from isodude/fix-leaky-test
Browse files Browse the repository at this point in the history
pkg/http: Fix leaky test
  • Loading branch information
JoelSpeed authored Mar 30, 2024
2 parents 3b11a51 + 8f7209b commit 0678626
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 32 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

## Changes since v7.6.0

- [#2539](https://github.com/oauth2-proxy/oauth2-proxy/pull/2539) pkg/http: Fix leaky test (@isodude)

# V7.6.0

## Release Highlights
Expand Down
20 changes: 14 additions & 6 deletions pkg/http/http_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
Expand All @@ -18,7 +19,7 @@ import (
var ipv4CertData, ipv6CertData []byte
var ipv4CertDataSource, ipv4KeyDataSource options.SecretSource
var ipv6CertDataSource, ipv6KeyDataSource options.SecretSource
var client *http.Client
var transport *http.Transport

func TestHTTPSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
Expand All @@ -28,6 +29,17 @@ func TestHTTPSuite(t *testing.T) {
RunSpecs(t, "HTTP")
}

func httpGet(ctx context.Context, url string) (*http.Response, error) {
c := &http.Client{
Transport: transport.Clone(),
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
return c.Do(req)
}

var _ = BeforeSuite(func() {
By("Generating a ipv4 self-signed cert for TLS tests", func() {
certBytes, keyBytes, err := util.GenerateCert("127.0.0.1")
Expand Down Expand Up @@ -70,11 +82,7 @@ var _ = BeforeSuite(func() {
certpool.AddCert(ipv4certificate)
certpool.AddCert(ipv6certificate)

transport := http.DefaultTransport.(*http.Transport).Clone()
transport = http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig.RootCAs = certpool

client = &http.Client{
Transport: transport,
}
})
})
55 changes: 29 additions & 26 deletions pkg/http/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gleak"
)

const hello = "Hello World!"
Expand Down Expand Up @@ -559,6 +560,8 @@ var _ = Describe("Server", func() {

AfterEach(func() {
cancel()
Eventually(Goroutines).ShouldNot(HaveLeaked())

})

Context("with an ipv4 http server", func() {
Expand All @@ -584,7 +587,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(listenAddr)
resp, err := httpGet(ctx, listenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand All @@ -599,13 +602,13 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

_, err := client.Get(listenAddr)
_, err := httpGet(ctx, listenAddr)
Expect(err).ToNot(HaveOccurred())

cancel()

Eventually(func() error {
_, err := client.Get(listenAddr)
_, err := httpGet(ctx, listenAddr)
return err
}).Should(HaveOccurred())
})
Expand Down Expand Up @@ -638,7 +641,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(secureListenAddr)
resp, err := httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand All @@ -653,13 +656,13 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

_, err := client.Get(secureListenAddr)
_, err := httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())

cancel()

Eventually(func() error {
_, err := client.Get(secureListenAddr)
_, err := httpGet(ctx, secureListenAddr)
return err
}).Should(HaveOccurred())
})
Expand All @@ -670,7 +673,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(secureListenAddr)
resp, err := httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand Down Expand Up @@ -709,7 +712,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(listenAddr)
resp, err := httpGet(ctx, listenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand All @@ -724,7 +727,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(secureListenAddr)
resp, err := httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand All @@ -739,19 +742,19 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

_, err := client.Get(listenAddr)
_, err := httpGet(ctx, listenAddr)
Expect(err).ToNot(HaveOccurred())
_, err = client.Get(secureListenAddr)
_, err = httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())

cancel()

Eventually(func() error {
_, err := client.Get(listenAddr)
_, err := httpGet(ctx, listenAddr)
return err
}).Should(HaveOccurred())
Eventually(func() error {
_, err := client.Get(secureListenAddr)
_, err := httpGet(ctx, secureListenAddr)
return err
}).Should(HaveOccurred())
})
Expand Down Expand Up @@ -781,7 +784,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(listenAddr)
resp, err := httpGet(ctx, listenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand All @@ -796,13 +799,13 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

_, err := client.Get(listenAddr)
_, err := httpGet(ctx, listenAddr)
Expect(err).ToNot(HaveOccurred())

cancel()

Eventually(func() error {
_, err := client.Get(listenAddr)
_, err := httpGet(ctx, listenAddr)
return err
}).Should(HaveOccurred())
})
Expand Down Expand Up @@ -836,7 +839,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(secureListenAddr)
resp, err := httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand All @@ -851,13 +854,13 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

_, err := client.Get(secureListenAddr)
_, err := httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())

cancel()

Eventually(func() error {
_, err := client.Get(secureListenAddr)
_, err := httpGet(ctx, secureListenAddr)
return err
}).Should(HaveOccurred())
})
Expand All @@ -868,7 +871,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(secureListenAddr)
resp, err := httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand Down Expand Up @@ -908,7 +911,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(listenAddr)
resp, err := httpGet(ctx, listenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand All @@ -923,7 +926,7 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

resp, err := client.Get(secureListenAddr)
resp, err := httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))

Expand All @@ -938,19 +941,19 @@ var _ = Describe("Server", func() {
Expect(srv.Start(ctx)).To(Succeed())
}()

_, err := client.Get(listenAddr)
_, err := httpGet(ctx, listenAddr)
Expect(err).ToNot(HaveOccurred())
_, err = client.Get(secureListenAddr)
_, err = httpGet(ctx, secureListenAddr)
Expect(err).ToNot(HaveOccurred())

cancel()

Eventually(func() error {
_, err := client.Get(listenAddr)
_, err := httpGet(ctx, listenAddr)
return err
}).Should(HaveOccurred())
Eventually(func() error {
_, err := client.Get(secureListenAddr)
_, err := httpGet(ctx, secureListenAddr)
return err
}).Should(HaveOccurred())
})
Expand Down

0 comments on commit 0678626

Please sign in to comment.