Skip to content

Commit

Permalink
net/http: add Transport.GetProxyConnectHeader
Browse files Browse the repository at this point in the history
Fixes golang#41048

Change-Id: I38e01605bffb6f85100c098051b0c416dd77f261
Reviewed-on: https://go-review.googlesource.com/c/go/+/259917
Trust: Brad Fitzpatrick <[email protected]>
Run-TryBot: Brad Fitzpatrick <[email protected]>
TryBot-Result: Go Bot <[email protected]>
Reviewed-by: Damien Neil <[email protected]>
  • Loading branch information
bradfitz committed Oct 6, 2020
1 parent db428ad commit 930fa89
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
23 changes: 22 additions & 1 deletion src/net/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,18 @@ type Transport struct {

// ProxyConnectHeader optionally specifies headers to send to
// proxies during CONNECT requests.
// To set the header dynamically, see GetProxyConnectHeader.
ProxyConnectHeader Header

// GetProxyConnectHeader optionally specifies a func to return
// headers to send to proxyURL during a CONNECT request to the
// ip:port target.
// If it returns an error, the Transport's RoundTrip fails with
// that error. It can return (nil, nil) to not add headers.
// If GetProxyConnectHeader is non-nil, ProxyConnectHeader is
// ignored.
GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error)

// MaxResponseHeaderBytes specifies a limit on how many
// response bytes are allowed in the server's response
// header.
Expand Down Expand Up @@ -313,6 +323,7 @@ func (t *Transport) Clone() *Transport {
ResponseHeaderTimeout: t.ResponseHeaderTimeout,
ExpectContinueTimeout: t.ExpectContinueTimeout,
ProxyConnectHeader: t.ProxyConnectHeader.Clone(),
GetProxyConnectHeader: t.GetProxyConnectHeader,
MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
ForceAttemptHTTP2: t.ForceAttemptHTTP2,
WriteBufferSize: t.WriteBufferSize,
Expand Down Expand Up @@ -1623,7 +1634,17 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
}
case cm.targetScheme == "https":
conn := pconn.conn
hdr := t.ProxyConnectHeader
var hdr Header
if t.GetProxyConnectHeader != nil {
var err error
hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr)
if err != nil {
conn.Close()
return nil, err
}
} else {
hdr = t.ProxyConnectHeader
}
if hdr == nil {
hdr = make(Header)
}
Expand Down
52 changes: 52 additions & 0 deletions src/net/http/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5174,6 +5174,57 @@ func TestTransportProxyConnectHeader(t *testing.T) {
}
}

func TestTransportProxyGetConnectHeader(t *testing.T) {
defer afterTest(t)
reqc := make(chan *Request, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "CONNECT" {
t.Errorf("method = %q; want CONNECT", r.Method)
}
reqc <- r
c, _, err := w.(Hijacker).Hijack()
if err != nil {
t.Errorf("Hijack: %v", err)
return
}
c.Close()
}))
defer ts.Close()

c := ts.Client()
c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
return url.Parse(ts.URL)
}
// These should be ignored:
c.Transport.(*Transport).ProxyConnectHeader = Header{
"User-Agent": {"foo"},
"Other": {"bar"},
}
c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
return Header{
"User-Agent": {"foo2"},
"Other": {"bar2"},
}, nil
}

res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
if err == nil {
res.Body.Close()
t.Errorf("unexpected success")
}
select {
case <-time.After(3 * time.Second):
t.Fatal("timeout")
case r := <-reqc:
if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
}
if got, want := r.Header.Get("Other"), "bar2"; got != want {
t.Errorf("CONNECT request Other = %q; want %q", got, want)
}
}
}

var errFakeRoundTrip = errors.New("fake roundtrip")

type funcRoundTripper func()
Expand Down Expand Up @@ -5842,6 +5893,7 @@ func TestTransportClone(t *testing.T) {
ResponseHeaderTimeout: time.Second,
ExpectContinueTimeout: time.Second,
ProxyConnectHeader: Header{},
GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
MaxResponseHeaderBytes: 1,
ForceAttemptHTTP2: true,
TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
Expand Down

0 comments on commit 930fa89

Please sign in to comment.